data_stardard.py

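"""
Data standardization agent: converts bank statements (PDF or image) into a
standardized transaction CSV. MinerU extracts Markdown tables from the input
file, and a DeepSeek chat model (driven through LangChain) maps the table
rows to structured transaction records.
"""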
import os
import asyncio
import io
import csv
import datetime
import httpx

# --- LangChain Imports ---
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser


# --- Core parser ---
class TransactionParserAgent:
    def __init__(self, api_key: str, multimodal_api_url: str, base_url: str = "https://api.deepseek.com"):
        # 1. Initialize the LangChain ChatOpenAI client.
        # DeepSeek is fully OpenAI-compatible, so ChatOpenAI is the standard choice.
        self.llm = ChatOpenAI(
            model="deepseek-chat",
            api_key=api_key,
            base_url=base_url,
            temperature=0.1,
            max_retries=3,  # LangChain's built-in retry mechanism
            model_kwargs={
                "response_format": {"type": "json_object"}  # force JSON mode
            },
            # Tune timeouts and connection limits by passing httpx clients
            # through (LangChain accepts http_client / http_async_client).
            http_client=httpx.Client(
                timeout=httpx.Timeout(300.0, read=300.0, connect=60.0),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
            ),
            # The sync http_client does not cover ainvoke; provide a matching
            # async client so the same timeouts apply to async calls.
            http_async_client=httpx.AsyncClient(
                timeout=httpx.Timeout(300.0, read=300.0, connect=60.0),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10)
            )
        )
        self.multimodal_api_url = multimodal_api_url
        # JSON output parser
        self.parser = JsonOutputParser()
    async def _invoke_miner_u(self, file_path: str) -> str:
        """Call MinerU and extract the raw Markdown rows (kept as a plain httpx call, since this is not an LLM)."""
        print(f"🚀 MinerU parsing: {os.path.basename(file_path)}")
        try:
            # MinerU is a standalone service, so we keep using httpx directly.
            async with httpx.AsyncClient() as client:
                with open(file_path, 'rb') as f:
                    files = {'file': (os.path.basename(file_path), f)}
                    data = {'folderId': 'text'}
                    response = await client.post(self.multimodal_api_url, files=files, data=data, timeout=120.0)
                if response.status_code == 200:
                    res_json = response.json()
                    full_md_list = []
                    for element in res_json.get('convert_json', []):
                        if 'md' in element:
                            full_md_list.append(element['md'])
                    return "\n\n".join(full_md_list)
                return ""
        except Exception as e:
            print(f"❌ MinerU call failed: {e}")
            return ""
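    # For reference, the code above assumes a MinerU response of roughly this
    # shape (inferred from the keys it reads; not an official schema):
    #   {"convert_json": [{"md": "| 日期 | 摘要 | ... |"}, {"md": "..."}]}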
    def _get_csv_prompt_template(self) -> ChatPromptTemplate:
        """Build the LangChain prompt template."""
        system_template = """
# Role
You are a high-precision bank statement conversion tool.
# Task
Convert the input Markdown table rows into a JSON array.
# Field Rules
1. txId: use the transaction serial number from the input if present; otherwise generate incrementing IDs starting from T{start_id:04d}.
2. txDate: transaction date in YYYY-MM-DD format.
3. txTime: transaction time in HH:mm:ss format (use 00:00:00 if unknown).
4. txAmount: transaction amount, as an absolute number.
5. txBalance: balance after the transaction; a float with thousands separators removed.
6. txDirection: transaction direction. Output only "收入" (income) or "支出" (expense), decided as follows:
   - With a "借/贷" (debit/credit) column: "借" is usually expense, "贷" usually income (unless it is a credit card; check the header).
   - With separate "收入/支出" (income/expense) columns: classify by column.
   - With signed amounts: "+" means income, "-" means expense.
   - If unsigned, decide from the table header.
7. txSummary: summary, purpose, business type, and similar remarks.
8. txCounterparty: counterparty (name and account number, if available).
# Constraints
- **Mandatory output format**:
  1. Return strictly a JSON array of objects.
  2. Every object must contain all 8 field names above as keys.
  3. Do not output any explanatory text or Markdown code-fence tags.
"""
        user_template = """# Input Data
{chunk_data}
# Output
JSON Array:
"""
        return ChatPromptTemplate.from_messages([
            ("system", system_template),
            ("user", user_template)
        ])
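    # A single record produced by this prompt is expected to look like
    # (values illustrative only):
    #   {"txId": "T0001", "txDate": "2024-01-05", "txTime": "09:30:00",
    #    "txAmount": 1500.0, "txDirection": "收入", "txBalance": 8200.5,
    #    "txSummary": "转账", "txCounterparty": "某公司 6222***"}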
    async def parse_to_csv(self, file_path: str) -> str:
        # 1. Fetch the full Markdown text and split it into lines.
        md_text = await self._invoke_miner_u(file_path)
        if not md_text:
            return ""
        # Initial split
        raw_lines = md_text.splitlines()
        # Keep only non-empty lines; the first usable one is the fallback header.
        clean_lines = [l.strip() for l in raw_lines if l.strip()]
        if len(clean_lines) < 2:
            return ""
        # --- Key improvement: locate the table header dynamically ---
        table_header = ""
        header_index = 0
        header_keywords = ["余额", "金额", "账号", "日期", "借/贷", "摘要"]
        for idx, line in enumerate(clean_lines):
            # A line counts as the header if it contains at least 2 keywords
            # and the Markdown table separator '|'.
            hit_count = sum(1 for kw in header_keywords if kw in line)
            if hit_count >= 2 and "|" in line:
                table_header = line
                header_index = idx
                break
        if not table_header:
            table_header = clean_lines[0]
            header_index = 0
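        # Example of a line the keyword scan would accept as a header
        # (illustrative; actual headers vary by bank):
        #   | 交易日期 | 摘要 | 借/贷 | 交易金额 | 账户余额 |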
        data_rows = []
        for line in clean_lines[header_index + 1:]:
            # Skip Markdown separator rows such as |---|---|.
            if all(c in '|- ' for c in line):
                continue
            # Skip repeated headers (e.g. one per page).
            if line == table_header:
                continue
            # Filter out page numbers or stray text MinerU may emit after the table.
            if "|" not in line:
                continue
            data_rows.append(line)
        csv_header = "txId,txDate,txTime,txAmount,txDirection,txBalance,txSummary,txCounterparty,createdAt\n"
        csv_content = csv_header
        batch_size = 15
        global_tx_counter = 1
        # Build the LCEL chain: Prompt -> LLM -> Parser
        chain = self._get_csv_prompt_template() | self.llm | self.parser
        # 2. Process the rows in chunks.
        for i in range(0, len(data_rows), batch_size):
            chunk = data_rows[i: i + batch_size]
            context_chunk = [table_header] + chunk
            chunk_str = "\n".join(context_chunk)
            print(f"🔄 Converting batch {i // batch_size + 1} with {len(chunk)} rows...")
            # print(f"Chunk to convert:\n{chunk_str}")
            try:
                # --- LangChain call: invoke the chain asynchronously ---
                data_data = await chain.ainvoke({
                    "start_id": global_tx_counter,
                    "chunk_data": chunk_str
                })
                # print(f"💡 LLM returned: {data_data}")
                # Compatibility handling: the parser usually returns a list or dict directly.
                if isinstance(data_data, dict):
                    # Look for a "transactions" key; if absent, assume the dict
                    # itself is the record we want (rare).
                    batch_data = data_data.get("transactions", [data_data])
                    # If it is still a dict (e.g. a single record), wrap it in a list.
                    if isinstance(batch_data, dict):
                        batch_data = [batch_data]
                elif isinstance(data_data, list):
                    batch_data = data_data
                else:
                    batch_data = []
                if batch_data:
                    output = io.StringIO()
                    createdAtStr = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL, lineterminator='\n')
                    print(f"✅ Batch converted: {len(batch_data)} records.")
                    for item in batch_data:
                        writer.writerow([
                            item.get("txId", ""),
                            item.get("txDate", ""),
                            item.get("txTime", ""),
                            item.get("txAmount", ""),
                            item.get("txDirection", ""),
                            item.get("txBalance", ""),
                            item.get("txSummary", ""),
                            item.get("txCounterparty", ""),
                            createdAtStr
                        ])
                    batch_csv_string = output.getvalue()
                    csv_content += batch_csv_string
                    global_tx_counter += len(batch_data)
            except Exception as e:
                print(f"⚠️ Batch failed: {e}")
        return csv_content
    async def parse_and_save_to_file(self, file_path: str, output_dir: str = "output") -> str:
        """Called by the workflow: parse the file, save the CSV, and return its full path."""
        current_script_path = os.path.abspath(__file__)
        current_dir = os.path.dirname(current_script_path)
        file_full_name = os.path.basename(file_path)
        file_name = os.path.splitext(file_full_name)[0]  # stem without extension, e.g. "11111"
        output_dir = os.path.normpath(os.path.join(current_dir, "..", "..", output_dir))
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"{file_name}_data_standard_{timestamp}.csv"
        full_path = os.path.join(output_dir, file_name)
        csv_result = await self.parse_to_csv(file_path)
        if csv_result:
            with open(full_path, "w", encoding="utf-8") as f:
                f.write(csv_result)
            return full_path
        else:
            raise Exception("Parsing failed: no valid content was produced")
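    # The output lands two directories above this script, under <output_dir>,
    # named <stem>_data_standard_<YYYYmmdd_HHMMSS>.csv, e.g.
    # data_files/11111_data_standard_20240101_120000.csv (timestamp illustrative).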
    async def run_workflow_task(self, input_file_path: str) -> dict:
        """Standard workflow entry point."""
        try:
            print(f"File to standardize: {input_file_path}")
            api_results_dir = "data_files"
            saved_path = await self.parse_and_save_to_file(input_file_path, api_results_dir)
            return {
                "status": "success",
                "file_path": saved_path,
                "file_name": os.path.basename(saved_path),
                "timestamp": datetime.datetime.now().isoformat()
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
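    # Result shape on success (path and timestamp illustrative):
    #   {"status": "success",
    #    "file_path": ".../data_files/11111_data_standard_20240101_120000.csv",
    #    "file_name": "11111_data_standard_20240101_120000.csv",
    #    "timestamp": "2024-01-01T12:00:00"}
    # On failure: {"status": "error", "message": "..."}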
# --- Run ---
async def main():
    agent = TransactionParserAgent(
        # Read the key from the environment rather than hard-coding a secret.
        api_key=os.environ.get("DEEPSEEK_API_KEY", ""),
        multimodal_api_url="http://103.154.31.78:20012/api/file/read"
    )
    current_script_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_script_path)
    # Simulate the workflow handing over a file to process.
    input_pdf = "data_files/11111.png"
    filepath = os.path.normpath(os.path.join(current_dir, "..", "..", input_pdf))
    if not os.path.exists(filepath):
        print(f"File does not exist: {filepath}")
        return
    result = await agent.run_workflow_task(filepath)
    if result["status"] == "success":
        print("🎯 [Data standardization] task complete!")
        print(f"📂 Standardized file written to: {result['file_path']}")
    else:
        print(f"❌ Task failed: {result['message']}")


if __name__ == "__main__":
    asyncio.run(main())