# data_stardard.py
import asyncio
import csv
import datetime
import io
import json
import os
from pathlib import Path

import httpx

# --- LangChain Imports ---
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# --- Core parser ---
  13. class TransactionParserAgent:
  14. def __init__(self, api_key: str, multimodal_api_url: str, base_url: str = "https://api.deepseek.com"):
  15. # 1. 初始化 LangChain ChatOpenAI 客户端
  16. # DeepSeek 完全兼容 OpenAI 接口,使用 ChatOpenAI 是标准做法
  17. self.llm = ChatOpenAI(
  18. model="deepseek-chat",
  19. api_key=api_key,
  20. base_url=base_url,
  21. temperature=0.1,
  22. max_retries=3, # LangChain 内置重试机制
  23. model_kwargs={
  24. "response_format": {"type": "json_object"} # 强制 JSON 模式
  25. },
  26. # 配置 httpx 客户端以优化超时和连接 (LangChain 允许透传 http_client)
  27. http_client=httpx.Client(
  28. timeout=httpx.Timeout(300.0, read=300.0, connect=60.0),
  29. limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
  30. )
  31. )
  32. self.multimodal_api_url = multimodal_api_url
  33. # 定义 JSON 解析器
  34. self.parser = JsonOutputParser()
  35. async def _invoke_miner_u(self, file_path: str) -> str:
  36. """调用 MinerU 并提取纯行数据 (保持 httpx 调用不变,因为这不是 LLM)"""
  37. print(f"🚀 MinerU 解析中: {os.path.basename(file_path)}")
  38. try:
  39. # MinerU 是独立服务,继续使用原生 httpx
  40. async with httpx.AsyncClient() as client:
  41. with open(file_path, 'rb') as f:
  42. files = {'file': (os.path.basename(file_path), f)}
  43. data = {'folderId': 'text'}
  44. response = await client.post(self.multimodal_api_url, files=files, data=data, timeout=120.0)
  45. if response.status_code == 200:
  46. res_json = response.json()
  47. full_md_list = []
  48. for element in res_json.get('convert_json', []):
  49. if element.get('type') == 'table' and 'md' in element:
  50. full_md_list.append(element['md'])
  51. return "\n\n".join(full_md_list)
  52. return ""
  53. except Exception as e:
  54. print(f"❌ MinerU 调用异常: {e}")
  55. return ""
  56. def _get_csv_prompt_template(self) -> ChatPromptTemplate:
  57. """
  58. 构造 LangChain 的 Prompt 模板
  59. """
  60. system_template = """
  61. # Role
  62. 你是一个高精度的银行账单转换工具。
  63. # Task
  64. 将输入的 Markdown 表格行转换为 JSON 数组。
  65. # Field Rules
  66. 1. txId: 如果输入数据中有交易流水号则直接使用,如果没有,从 T{start_id:04d} 开始递增生成。
  67. 2. txDate: 交易日期,格式为YYYY-MM-DD
  68. 3. txTime: 交易时间,格式为HH:mm:ss (未知填 00:00:00)
  69. 4. txAmount: 交易金额,绝对值数字
  70. 5. txBalance: 交易后余额。浮点数,移除千分位逗号。
  71. 6. txDirection: 交易方向。必须根据以下逻辑判断只输出“收入”或“支出”:
  72. - 若有“借/贷”列:“借”通常为支出,“贷”通常为收入(除非是信用卡,需结合表头)。
  73. - 若有“收入/支出”分列:按列归类。
  74. - 若金额带正负号:"+"为收入,"-"为支出。
  75. - 如果无符号,请结合表头判断。
  76. 7. txSummary: 摘要、用途、业务类型等备注。
  77. 8. txCounterparty: 交易对手方(名称及账号,如有)。
  78. # Constraints
  79. - **强制输出格式**:
  80. 1. 严格返回一个包含对象的 JSON 数组。
  81. 2. 每个对象必须包含上述 8 个字段名作为 Key。
  82. 3. 不要输出任何解释文字或 Markdown 代码块标签。
  83. """
  84. user_template = """# Input Data
  85. {chunk_data}
  86. # Output
  87. JSON Array:
  88. """
  89. return ChatPromptTemplate.from_messages([
  90. ("system", system_template),
  91. ("user", user_template)
  92. ])
  93. async def parse_to_csv(self, file_path: str) -> str:
  94. # 1. 获取完整 Markdown 文本并按行切分
  95. md_text = await self._invoke_miner_u(file_path)
  96. if not md_text:
  97. return ""
  98. # 初步切分
  99. raw_lines = md_text.splitlines()
  100. # 提取真正的第一行作为基准表头
  101. clean_lines = [l.strip() for l in raw_lines if l.strip()]
  102. if len(clean_lines) < 2: return ""
  103. table_header = clean_lines[0]
  104. data_rows = []
  105. for line in clean_lines[1:]:
  106. if all(c in '|- ' for c in line): continue
  107. if line == table_header: continue
  108. data_rows.append(line)
  109. csv_header = "txId,txDate,txTime,txAmount,txDirection,txBalance,txSummary,txCounterparty,createdAt\n"
  110. csv_content = csv_header
  111. batch_size = 15
  112. global_tx_counter = 1
  113. # 构建 LCEL Chain: Prompt -> LLM -> Parser
  114. chain = self._get_csv_prompt_template() | self.llm | self.parser
  115. # 2. 分块处理
  116. for i in range(0, len(data_rows), batch_size):
  117. chunk = data_rows[i: i + batch_size]
  118. context_chunk = [table_header] + chunk
  119. chunk_str = "\n".join(context_chunk)
  120. print(f"🔄 正在转换批次 {i // batch_size + 1},包含 {len(chunk)} 条数据...")
  121. print(f"待转换的数据块:\n{chunk_str}")
  122. try:
  123. # --- LangChain 调用 ---
  124. # 使用 ainvoke 异步调用链
  125. data_data = await chain.ainvoke({
  126. "start_id": global_tx_counter,
  127. "chunk_data": chunk_str
  128. })
  129. print(f"💡 LLM 返回数据: {data_data}")
  130. # 兼容处理:LangChain Parser 通常会直接返回 List 或 Dict
  131. if isinstance(data_data, dict):
  132. # 尝试寻找 transactions 键,如果没有则假设整个 dict 就是我们要的对象(虽然罕见)
  133. batch_data = data_data.get("transactions", [data_data])
  134. # 如果取出来还是 dict (例如单条记录),包一层 list
  135. if isinstance(batch_data, dict):
  136. batch_data = [batch_data]
  137. elif isinstance(data_data, list):
  138. batch_data = data_data
  139. else:
  140. batch_data = []
  141. if batch_data:
  142. output = io.StringIO()
  143. createdAtStr = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  144. writer = csv.writer(output, quoting=csv.QUOTE_ALL, lineterminator='\n')
  145. print(f"✅ 批次转换成功,包含 {len(batch_data)} 条记录。")
  146. for item in batch_data:
  147. writer.writerow([
  148. item.get("txId", ""),
  149. item.get("txDate", ""),
  150. item.get("txTime", ""),
  151. item.get("txAmount", ""),
  152. item.get("txDirection", ""),
  153. item.get("txBalance", ""),
  154. item.get("txSummary", ""),
  155. item.get("txCounterparty", ""),
  156. createdAtStr
  157. ])
  158. batch_csv_string = output.getvalue()
  159. csv_content += batch_csv_string
  160. global_tx_counter += len(batch_data)
  161. except Exception as e:
  162. print(f"⚠️ 批次执行失败: {e}")
  163. return csv_content
  164. async def parse_and_save_to_file(self, file_path: str, output_dir: str = "output") -> str:
  165. """
  166. 供 Workflow 调用:解析并保存文件,返回全路径名
  167. """
  168. current_script_path = os.path.abspath(__file__)
  169. current_dir = os.path.dirname(current_script_path)
  170. output_dir = os.path.normpath(os.path.join(current_dir, "..", output_dir))
  171. os.makedirs(output_dir, exist_ok=True)
  172. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  173. file_name = f"statement_{timestamp}.csv"
  174. full_path = os.path.join(output_dir, file_name)
  175. csv_result = await self.parse_to_csv(file_path)
  176. if csv_result:
  177. with open(full_path, "w", encoding="utf-8-sig") as f:
  178. f.write(csv_result)
  179. return full_path
  180. else:
  181. raise Exception("数据解析失败,未生成有效内容")
  182. async def run_workflow_task(self, input_file_path: str) -> dict:
  183. """
  184. 标准 Workflow 入口方法
  185. """
  186. try:
  187. print(f"传入文件路径:{input_file_path}")
  188. api_results_dir = "api_results"
  189. saved_path = await self.parse_and_save_to_file(input_file_path, api_results_dir)
  190. return {
  191. "status": "success",
  192. "file_path": saved_path,
  193. "file_name": os.path.basename(saved_path),
  194. "timestamp": datetime.datetime.now().isoformat()
  195. }
  196. except Exception as e:
  197. return {
  198. "status": "error",
  199. "message": str(e)
  200. }
  201. # --- 运行 ---
  202. async def main():
  203. agent = TransactionParserAgent(
  204. api_key="sk-8634dbc2866540c4b6003bb5733f23d8",
  205. multimodal_api_url="http://103.154.31.78:20012/api/file/read"
  206. )
  207. current_script_path = os.path.abspath(__file__)
  208. current_dir = os.path.dirname(current_script_path)
  209. # 模拟 Workflow 传入一个待处理文件
  210. input_pdf = "data_files/1.pdf"
  211. filepath = os.path.normpath(os.path.join(current_dir, "..", "..", input_pdf))
  212. if not os.path.exists(filepath):
  213. print(f"{filepath}文件不存在")
  214. return
  215. result = await agent.run_workflow_task(filepath)
  216. if result["status"] == "success":
  217. print(f"🎯 Workflow 任务完成!")
  218. print(f"📂 文件全路径: {result['file_path']}")
  219. else:
  220. print(f"❌ 任务失败: {result['message']}")
  221. if __name__ == "__main__":
  222. asyncio.run(main())