@@ -6,17 +6,18 @@ import csv
 import datetime
 import httpx
 import json
-import uuid
+import sqlite3
+import re

 # --- LangChain Imports ---
 from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.outputs import Generation
-import re

-class SafeJsonOutputParser(JsonOutputParser):
+# --- Utility class, kept unchanged ---
+class SafeJsonOutputParser(JsonOutputParser):
     def parse_result(self, result, *, partial: bool = False):
         if isinstance(result, list) and len(result) > 0:
            generation = result[0]
@@ -25,31 +26,28 @@ class SafeJsonOutputParser(JsonOutputParser):
         else:
             raise ValueError(f"Unexpected result type: {type(result)}")
         text = generation.text
-        # 1. Strip <think>...</think> reasoning blocks
         text = re.sub(r"<think>.*?</think>", "", text, flags=re.S).strip()
-
-        # 2. Strip ```json ``` fences
         text = re.sub(r"^```(?:json)?|```$", "", text, flags=re.I | re.M).strip()
-
-        # 3. Keep only the JSON body itself
         match = re.search(r"(\[\s*{.*}\s*\]|\{\s*\".*\"\s*\})", text, flags=re.S)
         if not match:
-            raise ValueError(f"Invalid json output after clean: {text[:200]}")
-
+            # Fallback: the LLM occasionally returns a bare SQL string instead of JSON
+            if "SELECT" in text.upper():
+                return {"sql": text}
+            raise ValueError(f"Invalid json output: {text[:200]}")
         json_text = match.group(1)
         return json.loads(json_text)
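
For reference, a minimal standalone sketch of what this cleanup chain does to a typical reasoning-model reply (the reply text below is invented):

```python
import re
import json

raw = "<think>map c3 to txAmount...</think>\n```json\n{\"sql\": \"SELECT c0 FROM temp_raw_data\"}\n```"
text = re.sub(r"<think>.*?</think>", "", raw, flags=re.S).strip()          # 1. drop the reasoning block
text = re.sub(r"^```(?:json)?|```$", "", text, flags=re.I | re.M).strip()  # 2. drop the code fences
match = re.search(r"(\[\s*{.*}\s*\]|\{\s*\".*\"\s*\})", text, flags=re.S)  # 3. isolate the JSON body
print(json.loads(match.group(1)))  # -> {'sql': 'SELECT c0 FROM temp_raw_data'}
```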

-# --- Core parser ---
 class TransactionParserAgent:
     def __init__(self, api_key: str, multimodal_api_url: str, base_url: str = "https://api.deepseek.com", model_name: str = "deepseek-chat"):
         # 1. Initialize the LangChain ChatOpenAI client
         # DeepSeek is fully OpenAI-compatible, so using ChatOpenAI is the standard approach
+        print(f"Using model: {model_name}")
         self.llm = ChatOpenAI(
             model=model_name,
             api_key=api_key,
             base_url=base_url,
-            temperature=0.1,
+            temperature=0.0,
             max_retries=3,  # LangChain's built-in retry mechanism
             # Configure the httpx client for timeouts and connections (LangChain passes http_client through)
             http_client=httpx.Client(
@@ -58,10 +56,7 @@ class TransactionParserAgent:
             )
         )
         self.multimodal_api_url = multimodal_api_url
-
-        # JSON output parser
         self.parser = SafeJsonOutputParser()
-
         # Initialize API-call tracking
         self.api_calls = []
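
The hunk above elides the actual `httpx.Client(...)` arguments. For context, a passthrough client of this kind is typically configured along these lines (the timeout and pool values here are illustrative assumptions, not the project's real settings):

```python
import httpx

# Hypothetical values: a generous read timeout for slow LLM responses,
# a small keep-alive pool for a single-agent process.
http_client = httpx.Client(
    timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
    limits=httpx.Limits(max_connections=20, max_keepalive_connections=5),
)
```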
@@ -187,8 +182,8 @@ class TransactionParserAgent:
                     if 'md' in element:
                         full_md_list.append(element['md'])
                     if 'rows' in element:
-                        dealRows+=len(element['rows'])
-                print(f"📊 Extraction result: {dealRows-1} rows extracted")
+                        dealRows += len(element['rows'])
+                print(f"📊 Extraction result: {dealRows - 1} rows extracted")
                 return "\n\n".join(full_md_list)
             return ""
         except Exception as e:
@@ -196,251 +191,330 @@ class TransactionParserAgent:
             return ""
         finally:
             print(f"✅ [Step 1 - Data extraction] finished")
-            print(f"⏱️ Elapsed: { time.perf_counter() - miner_start_time:.2f} s")
-    def _get_csv_prompt_template(self) -> ChatPromptTemplate:
+            print(f"⏱️ Elapsed: {time.perf_counter() - miner_start_time:.2f} s")
+
+    # --- 🆕 Core logic: the SQLite transformation engine ---
+    def _init_sqlite_db(self, data_rows: list, header_line: str, delimiter='|') -> tuple:
         """
-        Build the LangChain prompt template
+        Load the Markdown rows into a generic wide table in an in-memory SQLite database.
+        Returns: (conn, max_cols)
         """
+        # 1. Create the in-memory database
+        conn = sqlite3.connect(":memory:")
+        cursor = conn.cursor()
+
+        header_fingerprint = "".join(header_line.strip().strip('|').split())
+        header_added = False  # make sure only one header row enters the database
+
+        # 2. Find the maximum column count and build a generic wide table (row_id, c0, c1, ... cN)
+        max_cols = 0
+        parsed_rows = []
+        # Pre-processing: clean the Markdown delimiters
+        for row in data_rows:
+            # Strip leading/trailing |
+            clean_row = row.strip().strip('|')
+
+            # A. Filter out pure separator lines (e.g. | --- | --- |)
+            if not re.search(r'[\u4e00-\u9fa5a-zA-Z0-9]', clean_row):
+                continue
+
+            # B. Fingerprint the current row
+            current_fingerprint = "".join(clean_row.split())
+
+            # C. Core check: keep only the first occurrence of the header
+            if current_fingerprint == header_fingerprint:
+                if not header_added:
+                    # the first time the header fingerprint appears, let it into the database
+                    header_added = True
+                else:
+                    # skip any later identical header (page breaks, repeated headers, etc.)
+                    continue
+
+            # Split into cells
+            parts = [p.strip() for p in clean_row.split(delimiter)]
+            if len(parts) > max_cols:
+                max_cols = len(parts)
+            parsed_rows.append(parts)
+
+        if max_cols == 0:
+            return None, None
+
+        # Dynamic CREATE TABLE statement
+        cols_def = ", ".join([f"c{i} TEXT" for i in range(max_cols)])
+        create_sql = f"CREATE TABLE temp_raw_data (row_id INTEGER PRIMARY KEY AUTOINCREMENT, {cols_def});"
+        cursor.execute(create_sql)
+
+        # 3. Bulk-insert the data
+        insert_sql = f"INSERT INTO temp_raw_data ({', '.join([f'c{i}' for i in range(max_cols)])}) VALUES ({', '.join(['?' for _ in range(max_cols)])})"
+
+        # Pad rows: if a row is shorter than the longest one, fill with None
+        final_data = []
+        for p in parsed_rows:
+            padding = [None] * (max_cols - len(p))
+            final_data.append(p + padding)
+        cursor.executemany(insert_sql, final_data)
+        conn.commit()
+
+        return conn, max_cols
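
To make the generic wide table concrete, here is a small standalone sketch (with invented sample rows) of what the dynamic DDL and bulk insert produce for a three-column statement:

```python
import sqlite3

# Invented sample rows; in the real flow they come from the MinerU Markdown.
rows = [["2024-01-02", "1,000.00", "工资"], ["2024-01-05", "35.50", "餐饮"]]
max_cols = max(len(r) for r in rows)

conn = sqlite3.connect(":memory:")
cols_def = ", ".join(f"c{i} TEXT" for i in range(max_cols))
conn.execute(f"CREATE TABLE temp_raw_data (row_id INTEGER PRIMARY KEY AUTOINCREMENT, {cols_def})")

insert_cols = ", ".join(f"c{i}" for i in range(max_cols))
placeholders = ", ".join("?" for _ in range(max_cols))
conn.executemany(f"INSERT INTO temp_raw_data ({insert_cols}) VALUES ({placeholders})", rows)

print(conn.execute("SELECT * FROM temp_raw_data").fetchall())
# [(1, '2024-01-02', '1,000.00', '工资'), (2, '2024-01-05', '35.50', '餐饮')]
```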
+
+    def _get_sql_generation_prompt(self) -> ChatPromptTemplate:
         system_template = """
 # Role
-You are a high-precision bank statement conversion tool.
+You are a SQLite expert.

 # Task
-Convert the input Markdown table rows into a JSON array.
-
-# Field Rules
-1. txId: use the transaction id from the input if present; otherwise generate ids incrementally starting from T{start_id:04d}.
-2. txDate: transaction date, format YYYY-MM-DD
-3. txTime: transaction time, format HH:mm:ss (00:00:00 if unknown)
-4. txAmount: transaction amount, as an absolute number
-5. txBalance: balance after the transaction. Float, thousands separators removed.
-6. txDirection: transaction direction. Output only "收入" (income) or "支出" (expense), decided as follows:
-   - With a debit/credit column: "借" (debit) is usually expense, "贷" (credit) usually income (unless it is a credit card; check the header).
-   - With separate income/expense columns: classify by column.
-   - With signed amounts: "+" is income, "-" is expense.
-   - With no sign, decide from the header.
-7. txSummary: summary, purpose, business type and similar remarks.
-8. txCounterparty: counterparty (name and account number, if any).
-
-# Constraints
-- **Mandatory output format**:
-  1. Return strictly a JSON array of objects.
-  2. Every object must contain the 8 field names above as keys.
-  3. Do not output any explanatory text or Markdown code-fence tags.
-
-# Anti-Hallucination Rules
-- Never infer a field that does not explicitly appear in the raw data
-- Never compute or guess balances
-- Never complete counterparty names from common knowledge
-- If a field is missing, return the empty string ""
-"""
-        user_template = """# Input Data
-{chunk_data}
-
-# Output
-JSON Array:
-"""
-        return ChatPromptTemplate.from_messages([
-            ("system", system_template),
-            ("user", user_template)
-        ])
+You have a table named `temp_raw_data` that holds the raw OCR-extracted data.
+Its columns are named `c0`, `c1`, `c2` ... `cN`.
+Based on the provided header and data sample, write one SQL query that maps the raw columns to the standard output fields.
+
+# Target Schema (Output Columns)
+Your SQL must SELECT the following fields, in exactly this order:
+1. `txId`: transaction id. If the raw data has none, use `row_id`.
+2. `txDate`: transaction date (format YYYY-MM-DD).
+3. `txTime`: transaction time (format HH:mm:ss). Return '00:00:00' if absent.
+4. `txAmount`: transaction amount (absolute number, **commas must be removed**, cast to REAL/FLOAT).
+5. `txDirection`: transaction direction (must be derived logically; output '收入' or '支出').
+6. `txBalance`: balance (commas removed).
+7. `txSummary`: summary/purpose.
+8. `txCounterparty`: counterparty account/name.
+
+# Logic Rules (Crucial!)
+1. **Direction Logic**:
+   - With a separate debit/credit column: usually "借" = `支出`, "贷" = `收入`.
+   - With separate income/expense columns: whichever column has a value decides the direction.
+   - With signed amounts: a minus sign usually means expense.
+   - Use SQL `CASE WHEN ... THEN ... ELSE ... END` for this.
+2. **Data Cleaning**:
+   - Amount fields must have thousands separators removed: `CAST(REPLACE(c?, ',', '') AS REAL)`
+   - Dates must be normalized.
+
+# Output JSON Format
+```json
+{{
+  "sql": "SELECT ... FROM temp_raw_data WHERE ..."
+}}
+```
+"""
+        user_template = """
+# Table Info
+Max Columns: {max_cols}
+Generic Column Names: c0, c1, ... c{max_cols_minus_1}
+
+# Data Preview (Header + First 3 Rows)
+{data_preview}
+
+# Instruction
+Write the SQL statement that extracts and cleans the data.
+Note: do not wrap the SQL in Markdown code fences; return the JSON directly.
+Skip the header row (row_id = 1 is usually the header, so use WHERE row_id > 1).
+"""
+        return ChatPromptTemplate.from_messages([
+            ("system", system_template),
+            ("user", user_template),
+        ])
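
For intuition, this is roughly the reply shape the prompt steers the model toward, shown for a hypothetical layout where c0=date, c1=summary, c2=debit/credit flag, c3=amount, c4=balance, c5=counterparty (the column assignments are invented, not a fixed mapping):

```python
# Hypothetical expected LLM reply for the made-up column layout above.
expected_reply = {
    "sql": (
        "SELECT row_id AS txId, c0 AS txDate, '00:00:00' AS txTime, "
        "ABS(CAST(REPLACE(c3, ',', '') AS REAL)) AS txAmount, "
        "CASE WHEN c2 = '借' THEN '支出' ELSE '收入' END AS txDirection, "
        "CAST(REPLACE(c4, ',', '') AS REAL) AS txBalance, "
        "c1 AS txSummary, c5 AS txCounterparty "
        "FROM temp_raw_data WHERE row_id > 1"
    )
}
```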
+
+    async def _generate_transform_sql(self, header_row: str, sample_rows: list, max_cols: int) -> str:
+        """Have the LLM write the transformation SQL."""
+        # Build the preview with c0, c1 ... column-name hints so the LLM can map columns
+        preview_text = ""
+
+        # Header preview
+        header_parts = [p.strip() for p in header_row.strip().strip('|').split('|')]
+        header_map = " | ".join([f"c{i}({val})" for i, val in enumerate(header_parts)])
+        preview_text += f"Mapping Hint: {header_map}\n"
+        preview_text += "-" * 50 + "\n"
+
+        # Data preview
+        for row in sample_rows:
+            preview_text += row + "\n"
+
+        prompt_params = {
+            "max_cols": max_cols,
+            "max_cols_minus_1": max_cols - 1,
+            "data_preview": preview_text
+        }
+        # Record the API call start time
+        call_start_time = datetime.datetime.now()
+
+        chain = self._get_sql_generation_prompt() | self.llm | self.parser
+
+        print("🧠 [LLM] Generating the SQL cleaning logic...")
+        result = ""
+        try:
+            result = await chain.ainvoke(prompt_params)
+            sql = result.get("sql")
+            print(f"💡 [LLM] Generated SQL:\n{sql}")
+            return sql
+        except Exception as e:
+            print(f"❌ SQL generation failed: {e}")
+            return ""
+        finally:
+            # Record the API call end time
+            call_end_time = datetime.datetime.now()
+            # Record the API call result - simplified: keep only the prompt and result data
+            call_id = f"api_llm_transform_{(call_end_time - call_start_time).total_seconds():.2f}"
+            # Try to extract the prompt from the chain (if possible)
+            prompt_content = ""
+            try:
+                # Try to fetch the final message content from the chain
+                if hasattr(chain, 'get_prompts'):
+                    prompts = chain.get_prompts()
+                    if prompts:
+                        prompt_content = str(prompts[-1])
+                else:
+                    # If unavailable, build a basic prompt description
+                    prompt_content = f"Input data, max_cols: {max_cols}, preview_text: {preview_text}..."
+            except Exception:
+                prompt_content = f"Input data, max_cols: {max_cols}, preview_text: {preview_text}..."
+
+            api_call_info = {
+                "call_id": call_id,
+                "start_time": call_start_time.isoformat(),
+                "end_time": call_end_time.isoformat(),
+                "duration": (call_end_time - call_start_time).total_seconds(),
+                "prompt": prompt_content,
+                "input_params": {
+                    "max_cols": max_cols,
+                    "max_cols_minus_1": max_cols - 1,
+                    "data_preview": preview_text
+                },
+                "llm_result": result
+            }
+            self.api_calls.append(api_call_info)
+
+            # Save the API result to a file (Markdown format, easier to read)
+            # Use the run id to create a dedicated folder
+            run_id = os.environ.get('FLOW_RUN_ID', 'default')
+            api_results_dir = f"api_results_{run_id}"
+            os.makedirs(api_results_dir, exist_ok=True)
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"{timestamp}_{call_id}.md"
+            filepath = os.path.join(api_results_dir, filename)
+            try:
+                with open(filepath, 'w', encoding='utf-8') as f:
+                    f.write("# Data transformation result\n\n")
+                    f.write("## Call info\n\n")
+                    f.write(f"- Call ID: {call_id}\n")
+                    f.write(f"- Start time: {call_start_time.isoformat()}\n")
+                    f.write(f"- End time: {call_end_time.isoformat()}\n")
+                    f.write(f"- Duration: {(call_end_time - call_start_time).total_seconds():.2f} s\n")
+                    f.write("\n## Prompt\n\n")
+                    f.write("```\n")
+                    f.write(api_call_info["prompt"])
+                    f.write("\n```\n\n")
+                    f.write("## Input params\n\n")
+                    f.write("```json\n")
+                    f.write(json.dumps(api_call_info["input_params"], ensure_ascii=False, indent=2))
+                    f.write("\n```\n\n")
+                    f.write("## LLM result\n\n")
+                    f.write("```json\n")
+                    f.write(json.dumps(api_call_info["llm_result"], ensure_ascii=False, indent=2))
+                    f.write("\n```\n")
+                print(f"[API_RESULT] Saved API result file: {filepath}")
+            except Exception as e:
+                print(f"[ERROR] Failed to save API result file: {filepath}, error: {str(e)}")
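
The `prompt | llm | parser` composition above is plain LCEL; stripped of the logging, the calling pattern is just the following (endpoint and key are placeholders, so running it requires a reachable model):

```python
import asyncio
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

prompt = ChatPromptTemplate.from_messages([
    ("system", "Reply with JSON only."),
    ("user", "Return {{\"sql\": \"...\"}} for a table with {max_cols} columns."),
])
llm = ChatOpenAI(model="deepseek-chat", api_key="sk-placeholder", base_url="https://api.deepseek.com")
chain = prompt | llm | JsonOutputParser()  # each stage feeds its output to the next

# ainvoke returns the parsed output, e.g. {"sql": "..."}
result = asyncio.run(chain.ainvoke({"max_cols": 6}))
```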

     async def parse_to_csv(self, file_path: str) -> str:
-        # 1. Fetch the full Markdown text and split it into lines
+        # 1. Fetch the Markdown
         md_text = await self._invoke_miner_u(file_path)
-        if not md_text:
-            return ""
+        if not md_text: return ""
+
         # Record the start time (time.perf_counter for high resolution)
-        switch_start_time = time.perf_counter()
+        start_time = time.perf_counter()
         print("\n" + "=" * 40)
         print("📌 [Step 2 - Standardization] started")
-        # Initial split
+
+        # 2. Pre-process the data rows
         raw_lines = md_text.splitlines()
-        # Take the actual first row as the baseline header
-        clean_lines = [l.strip() for l in raw_lines if l.strip()]
-        if len(clean_lines) < 2: return ""
+        clean_lines = [l.strip() for l in raw_lines if l.strip() and "|" in l]
+        # Simple header detection (a line containing 2+ keywords)
+        header_line = ""
+        header_idx = 0
+        keywords = ["日期", "金额", "余额", "摘要", "用途", "借", "贷"]
+        for idx, line in enumerate(clean_lines):
+            if sum(1 for k in keywords if k in line) >= 2:
+                header_line = line
+                header_idx = idx
+                break

-        # --- [Core improvement: find the header dynamically] ---
-        table_header = ""
-        header_index = 0
+        if not header_line:
+            header_line = clean_lines[0]

-        header_keywords = ["余额", "金额", "账号", "日期", "借/贷", "摘要"]
+        # Data rows (keep the raw data; it is loaded into the DB next)
+        data_rows = clean_lines  # the header goes in too; it is filtered out later via row_id > header_idx + 1

-        for idx, line in enumerate(clean_lines):
-            # A line counts as the header if it contains 2+ keywords plus the Markdown table separator '|'
-            hit_count = sum(1 for kw in header_keywords if kw in line)
-            if hit_count >= 2 and "|" in line:
-                table_header = line
-                header_index = idx
-                break
+        # 3. Load into SQLite
+        conn, max_cols = self._init_sqlite_db(data_rows, header_line)
+        if not conn:
+            return ""

-        if not table_header:
-            table_header = clean_lines[0]
-            header_index = 0
-        data_rows = []
-        for line in clean_lines[header_index + 1:]:
-            if all(c in '|- ' for c in line): continue
-            if line == table_header: continue
-            # Filter out page numbers or stray text MinerU sometimes emits after the table
-            if "|" not in line: continue
-            data_rows.append(line)
-
-        csv_header = "txId,txDate,txTime,txAmount,txDirection,txBalance,txSummary,txCounterparty,createdAt\n"
-        csv_content = csv_header
-
-        batch_size = 15
-        global_tx_counter = 1
-
-        # Build the LCEL chain: Prompt -> LLM -> Parser
-        chain = self._get_csv_prompt_template() | self.llm | self.parser
-
-        # 2. Process in chunks
-        for i in range(0, len(data_rows), batch_size):
-            chunk = data_rows[i: i + batch_size]
-            context_chunk = [table_header] + chunk
-            chunk_str = "\n".join(context_chunk)
-            # 1. Record the start time (time.perf_counter for high resolution)
-            start_time = time.perf_counter()
-            print(f"🔄 Converting batch {i // batch_size + 1} via the LLM, {len(chunk)} rows...")
-            # print(f"Chunk to convert:\n{chunk_str}")
-            try:
-                # --- LangChain call ---
-                # Invoke the chain asynchronously with ainvoke
-                # Record the API call start time
-                call_start_time = datetime.datetime.now()
+        try:
+            # 4. Have the LLM generate the SQL
+            # Use the header plus the first 3 data rows as the sample
+            sample_data = clean_lines[header_idx:header_idx + 4]
+            sql_query = await self._generate_transform_sql(header_line, sample_data, max_cols)

-                data_data = await chain.ainvoke({
-                    "start_id": global_tx_counter,
-                    "chunk_data": chunk_str
-                })
+            if not sql_query:
+                return ""

-                # Record the API call end time
-                call_end_time = datetime.datetime.now()
+            # 5. Execute the SQL
+            cursor = conn.cursor()

-                # Record the API call result - simplified: keep only the prompt and result data
-                call_id = f"api_llm_transform_{'{:.2f}'.format((call_end_time - call_start_time).total_seconds())}"
+            # For safety, make sure the SQL is read-only
+            if "DROP" in sql_query.upper() or "DELETE" in sql_query.upper():
+                raise ValueError("Unsafe SQL detected")

-                # Try to extract the prompt from the chain (if possible)
-                prompt_content = ""
-                try:
-                    # Try to fetch the final message content from the chain
-                    if hasattr(chain, 'get_prompts'):
-                        prompts = chain.get_prompts()
-                        if prompts:
-                            prompt_content = str(prompts[-1])
-                    else:
-                        # If unavailable, build a basic prompt description
-                        prompt_content = f"Batch conversion, start_id: {global_tx_counter}, chunk_data: {chunk_str[:200]}..."
-                except:
-                    prompt_content = f"Batch conversion, start_id: {global_tx_counter}, chunk_data: {chunk_str[:200]}..."
-
-                api_call_info = {
-                    "call_id": call_id,
-                    "start_time": call_start_time.isoformat(),
-                    "end_time": call_end_time.isoformat(),
-                    "duration": (call_end_time - call_start_time).total_seconds(),
-                    "prompt": prompt_content,
-                    "input_params": {
-                        "start_id": global_tx_counter,
-                        "chunk_data": chunk_str
-                    },
-                    "llm_result": data_data
-                }
-                self.api_calls.append(api_call_info)
-
-                # Save the API result to a file (Markdown format, easier to read)
-                # Use the run id to create a dedicated folder
-                run_id = os.environ.get('FLOW_RUN_ID', 'default')
-                api_results_dir = f"api_results_{run_id}"
-                os.makedirs(api_results_dir, exist_ok=True)
-                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-                filename = f"{timestamp}_{call_id}.md"
-                filepath = os.path.join(api_results_dir, filename)
+            # The LLM sometimes forgets to filter the header row; we would handle that around the SQL or in the prompt
+            # The simple approach here is to trust the SQL, or to test it with an appended LIMIT

-                try:
-                    with open(filepath, 'w', encoding='utf-8') as f:
-                        f.write("# Data transformation result\n\n")
-                        f.write("## Call info\n\n")
-                        f.write(f"- Call ID: {call_id}\n")
-                        f.write(f"- Start time: {call_start_time.isoformat()}\n")
-                        f.write(f"- End time: {call_end_time.isoformat()}\n")
-                        f.write(f"- Duration: {(call_end_time - call_start_time).total_seconds():.2f} s\n")
-                        f.write("\n## Prompt\n\n")
-                        f.write("```\n")
-                        f.write(api_call_info["prompt"])
-                        f.write("\n```\n\n")
-                        f.write("## Input params\n\n")
-                        f.write("```json\n")
-                        f.write(json.dumps(api_call_info["input_params"], ensure_ascii=False, indent=2))
-                        f.write("\n```\n\n")
-                        f.write("## LLM result\n\n")
-                        f.write("```json\n")
-                        f.write(json.dumps(api_call_info["llm_result"], ensure_ascii=False, indent=2))
-                        f.write("\n```\n")
-                    print(f"[API_RESULT] Saved API result file: {filepath}")
-                except Exception as e:
-                    print(f"[ERROR] Failed to save API result file: {filepath}, error: {str(e)}")
-
-                # print(f"💡 LLM returned: {data_data}")
-
-                # Compatibility: the LangChain parser usually returns a List or a Dict directly
-                if isinstance(data_data, dict):
-                    # Look for a "transactions" key; failing that, assume the whole dict is the object (rare)
-                    batch_data = data_data.get("transactions", [data_data])
-                    # If it is still a dict (e.g. a single record), wrap it in a list
-                    if isinstance(batch_data, dict):
-                        batch_data = [batch_data]
-                elif isinstance(data_data, list):
-                    batch_data = data_data
-                else:
-                    batch_data = []
-
-                if batch_data:
-                    final_list = self._validate_and_reconcile(batch_data)
-                    final_list = sorted(final_list, key=lambda x: (x['txId']))
-                    output = io.StringIO()
-                    createdAtStr = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-                    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL, lineterminator='\n')
-
-                    print(f"✅ Batch converted, {len(final_list)} records.")
-
-                    for item in final_list:
-                        writer.writerow([
-                            item.get("txId", ""),
-                            item.get("txDate", ""),
-                            item.get("txTime", ""),
-                            item.get("txAmount", ""),
-                            item.get("txDirection", ""),
-                            item.get("txBalance", ""),
-                            item.get("txSummary", ""),
-                            item.get("txCounterparty", ""),
-                            createdAtStr
-                        ])
-
-                    batch_csv_string = output.getvalue()
-                    csv_content += batch_csv_string
-
-                    global_tx_counter += len(final_list)
+            print("🚀 [SQLite] Executing the query...")
+            cursor.execute(sql_query)
+            results = cursor.fetchall()

-            except Exception as e:
-                print(f"⚠️ Batch failed: {e}")
-            finally:
-                end_time = time.perf_counter()
-                elapsed_time = end_time - start_time
-                print(f"⏱️ Elapsed: {elapsed_time:.2f} s")
-        print(f"📊 Conversion result: {global_tx_counter - 1} rows converted")
-        print(f"✅ [Step 2 - Standardization] finished")
-        print(f"⏱️ Total elapsed: {time.perf_counter() - switch_start_time:.2f} s")
-        return csv_content
+            print(f"✅ Extraction succeeded, {len(results)} rows")
+
+            # 6. Export as a CSV string
+            output = io.StringIO()
+            writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL, lineterminator='\n')
+
+            # Write the standard header
+            csv_header = ["txId", "txDate", "txTime", "txAmount", "txDirection", "txBalance", "txSummary",
+                          "txCounterparty", "createdAt"]
+            writer.writerow(csv_header)
+
+            created_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

+            for row in results:
+                # row is a tuple (id, date, time, amount, direction, balance, summary, counterparty)
+                # Convert the tuple to a list and append createdAt
+                row_list = list(row)
+
+                # --- 🆕 New: take the absolute value of txAmount ---
+                try:
+                    raw_amount = str(row_list[3]).replace(',', '')  # strip commas once more, just in case
+                    if raw_amount:
+                        # Convert to float and take the absolute value (kept as float)
+                        row_list[3] = abs(float(raw_amount))
+                except (ValueError, TypeError):
+                    # If conversion fails (e.g. OCR produced text), fall back to 0.0
+                    print(f"⚠️ Amount conversion failed: {row_list[3]}")
+                    row_list[3] = 0.0
+
+                # Safety cleanup: handle possible None values
+                row_list = [str(x) if x is not None else "" for x in row_list]
+
+                # Keep only the first 8 fields (in case the LLM selected extras)
+                final_row = row_list[:8] + [created_at]
+                writer.writerow(final_row)
+            return output.getvalue()
+        except sqlite3.Error as e:
+            print(f"❌ SQLite execution error: {e}")
+            # A retry loop could go here: feed the error back to the LLM so it can fix the SQL
+            return ""
+        finally:
+            conn.close()
+            print("✅ [Step 2 - Standardization] finished")
+            print(f"⏱️ Total elapsed: {time.perf_counter() - start_time:.2f} s")
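
The `except sqlite3.Error` branch above only logs and bails out, but its comment suggests feeding the error back to the LLM. A minimal sketch of that self-repair idea (the `regenerate` callback is a hypothetical helper, not something this file defines):

```python
import sqlite3

async def execute_with_repair(conn, sql_query: str, regenerate, max_attempts: int = 3):
    """Run the query; on failure, ask the LLM to fix the SQL and retry."""
    for attempt in range(max_attempts):
        try:
            return conn.execute(sql_query).fetchall()
        except sqlite3.Error as e:
            print(f"⚠️ Attempt {attempt + 1} failed: {e}")
            # regenerate is an async callable: (bad_sql, error_message) -> fixed_sql
            sql_query = await regenerate(sql_query, str(e))
    raise RuntimeError("SQL repair exhausted all attempts")
```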
+
+    # --- Workflow entry point ---
     async def parse_and_save_to_file(self, file_path: str, output_dir: str = "output") -> str:
-        """
-        Called by the workflow: parse the file, save it, and return the full path
-        """
         current_script_path = os.path.abspath(__file__)
         current_dir = os.path.dirname(current_script_path)
         file_full_name = os.path.basename(file_path)
-        file_name = os.path.splitext(file_full_name)[0]  # without the extension, e.g. 11111
+        file_name = os.path.splitext(file_full_name)[0]
         output_dir = os.path.normpath(os.path.join(current_dir, "..", "..", output_dir))

         os.makedirs(output_dir, exist_ok=True)
@@ -458,17 +532,13 @@ JSON Array:
             raise Exception("Data parsing failed: no valid content was generated")

     async def run_workflow_task(self, input_file_path: str) -> dict:
-        """
-        Standard workflow entry method
-        """
         # 1. Record the start time (time.perf_counter for high resolution)
         start_time = time.perf_counter()
         print("BEGIN --- data standardization task started ---")
         try:
             print(f"File to standardize: {input_file_path}")
-            api_results_dir = "data_files"
-            saved_path = await self.parse_and_save_to_file(input_file_path, api_results_dir)
-
+            saved_path = await self.parse_and_save_to_file(input_file_path, "data_files")
+            print(f"Result file saved to: {saved_path}")
             return {
                 "status": "success",
                 "file_path": saved_path,
@@ -476,10 +546,7 @@ JSON Array:
                 "timestamp": datetime.datetime.now().isoformat()
             }
         except Exception as e:
-            return {
-                "status": "error",
-                "message": str(e)
-            }
+            return {"status": "error", "message": str(e)}
         finally:
             end_time = time.perf_counter()
             elapsed_time = end_time - start_time
@@ -505,14 +572,16 @@ async def data_standardize(api_key: str, base_url: str, model_name: str, multimo
 # --- Run ---
 async def main():
     agent = TransactionParserAgent(
-        api_key="sk-8634dbc2866540c4b6003bb5733f23d8",
-        multimodal_api_url="http://103.154.31.78:20012/api/file/read"
+        api_key="",
+        multimodal_api_url="http://103.154.31.78:20012/api/file/read",
+        model_name="Qwen3-32B",
+        base_url="http://10.192.72.12:9996/v1",
     )

     current_script_path = os.path.abspath(__file__)
     current_dir = os.path.dirname(current_script_path)
     # Simulate the workflow passing in a file to process
-    input_pdf = "data_files/4.pdf"
+    input_pdf = "data_files/11111.png"
     filepath = os.path.normpath(os.path.join(current_dir, "..", "..", input_pdf))

     if not os.path.exists(filepath):