Rean 1 lună în urmă
părinte
comite
b5e3000210

Fișier diff suprimat deoarece este prea mare
+ 1 - 0
陈敬安/theme3/bad_cases_qwen3-0.6b_20250722_102517.csv


+ 763 - 0
陈敬安/theme3/theme3_ex1.ipynb

@@ -0,0 +1,763 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "9e428868",
+   "metadata": {},
+   "source": [
+    "## 练习1:情感分类\n",
+    "\n",
+    "### 数据集: data/acllmdb_sentiment_small\n",
+    "包含样本:\n",
+    "- negative: 包含负面评价120条\n",
+    "- positive: 包含正面评价120条\n",
+    "\n",
+    "### 目标\n",
+    "取出下面每一条文本, 并使用大模型预测每个评价的文本是 positive(正面)/negative(负面),并给出大模型预测的准确率(accuracy).\n",
+    "\n",
+    "Accuracy = 预测正确的样本数量 / 总样本(240)\n",
+    "\n",
+    "### 要求\n",
+    "1. 分别使用大模型(qwen3-4b)三种输出方式来进行这个测试\n",
+    "   1. text: 大模型纯文本生成\n",
+    "   2. json mode / json schema: 大模型结构化输出能力\n",
+    "   3. ⭐️⭐️⭐️(80%)tool choices / function call: 使用大模型工具调用能力\n",
+    "2. 能够正确输出每种方法的 metrics:  \n",
+    "    1. 必须:Accuracy(准确率)\n",
+    "    2. 【非必须】:Precision、Recall、F1 score\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "2fb513d9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "已设置 OpenAI 客户端,base_url: http://103.154.31.78:20001/compatible-mode/v1/\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from dotenv import load_dotenv\n",
+    "from openai import OpenAI\n",
+    "import glob\n",
+    "import json\n",
+    "from tqdm import tqdm\n",
+    "from openai import BadRequestError\n",
+    "\n",
+    "def create_client():\n",
+    "    load_dotenv()\n",
+    "    client = OpenAI(\n",
+    "        base_url=os.getenv(\"BAILIAN_API_BASE_URL\"),\n",
+    "        api_key=os.getenv(\"BAILIAN_API_KEY\")\n",
+    "    )\n",
+    "    return client\n",
+    "\n",
+    "client = create_client()\n",
+    "print(f\"已设置 OpenAI 客户端,base_url: {client.base_url}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5850031f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 初始化,读取文件\n",
+    "def read_files(directory):\n",
+    "    texts = []\n",
+    "    for filepath in glob.glob(os.path.join(directory, '*')):\n",
+    "        with open(filepath, 'r', encoding='utf-8') as f:\n",
+    "            texts.append(f.read())\n",
+    "    return texts\n",
+    "\n",
+    "positive_texts = read_files('../../data/acllmdb_sentiment_small/positive')\n",
+    "negative_texts = read_files('../../data/acllmdb_sentiment_small/negative')\n",
+    "\n",
+    "# 输出方式\n",
+    "output_modes = {\n",
+    "    \"text\": {\n",
+    "        \"system_message\": \"You are a sentiment analyzer. Reply with only one word: either 'positive' or 'negative'.\",\n",
+    "        \"user_message\": \"Analyze the sentiment of this text: {text}\"\n",
+    "    },\n",
+    "    \"json\": {\n",
+    "        \"system_message\": \"You are a sentiment analyzer. Reply in JSON format with a 'sentiment' field that is either 'positive' or 'negative'.\",\n",
+    "        \"user_message\": \"Analyze the sentiment of this text and return JSON: {text}\"\n",
+    "    },\n",
+    "    \"function\": {\n",
+    "        \"system_message\": \"You are a sentiment analyzer. Use the provided function to analyze sentiment.\",\n",
+    "        \"user_message\": \"Analyze the sentiment of this text: {text}\",\n",
+    "        \"functions\": [\n",
+    "            {\n",
+    "                \"type\": \"function\",\n",
+    "                \"function\":{\n",
+    "                    \"name\": \"analyze_sentiment\",\n",
+    "                    \"description\": \"Analyze the sentiment of a text\",\n",
+    "                    \"parameters\": {\n",
+    "                        \"type\": \"object\",\n",
+    "                        \"properties\": {\n",
+    "                            \"sentiment\": {\n",
+    "                                \"type\": \"string\",\n",
+    "                                \"enum\": [\"positive\", \"negative\"],\n",
+    "                                \"description\": \"The sentiment of the text\"\n",
+    "                            }\n",
+    "                        },\n",
+    "                        \"required\": [\"sentiment\"]\n",
+    "                    },\n",
+    "                    \"strict\": True\n",
+    "                }\n",
+    "            }\n",
+    "        ]\n",
+    "    },\n",
+    "    \"function_reason\": {\n",
+    "        \"system_message\": \"You are a sentiment analyzer. Use the provided function to analyze sentiment and explain your reasoning. You must use the provided function and return the right json format containing two properties, which is sentiment and reasoning, or you will be fired. Specially, you should return reasoning in Chinese.\",\n",
+    "        \"user_message\": \"Analyze the sentiment of this text: {text}\",\n",
+    "        \"functions\": [{\n",
+    "            \"type\": \"function\",\n",
+    "            \"function\": {\n",
+    "                \"name\": \"analyze_sentiment_with_reasoning\",\n",
+    "                \"description\": \"Analyze the sentiment of a text and provide reasoning\",\n",
+    "                \"parameters\": {\n",
+    "                    \"type\": \"object\",\n",
+    "                    \"properties\": {\n",
+    "                        \"sentiment\": {\n",
+    "                            \"type\": \"string\",\n",
+    "                            \"enum\": [\"positive\", \"negative\"],\n",
+    "                            \"description\": \"The sentiment of the text\"\n",
+    "                        },\n",
+    "                        \"reasoning\": {\n",
+    "                            \"type\": \"string\",\n",
+    "                            \"description\": \"Explanation for the sentiment classification in Chinese\"\n",
+    "                        }\n",
+    "                    },\n",
+    "                    \"required\": [\"sentiment\", \"reasoning\"]\n",
+    "                },\n",
+    "                \"strict\": True\n",
+    "            }\n",
+    "        }]\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "def get_sentiment(client, model, text, mode='text'):\n",
+    "    \"\"\"获取文本情感\n",
+    "    \n",
+    "    Args:\n",
+    "        client: OpenAI客户端\n",
+    "        model: 使用的模型名称\n",
+    "        text: 输入文本\n",
+    "        mode: 输出模式,可选 'text', 'json', 'function', 'function_reason'\n",
+    "    \n",
+    "    Returns:\n",
+    "        str: 'positive' 或 'negative'\n",
+    "    \"\"\"\n",
+    "    messages = [\n",
+    "        {\"role\": \"system\", \"content\": output_modes[mode][\"system_message\"]},\n",
+    "        {\"role\": \"user\", \"content\": output_modes[mode][\"user_message\"].format(text=text)}\n",
+    "    ]\n",
+    "    \n",
+    "    if mode == 'function' or mode == 'function_reason':\n",
+    "        response = client.chat.completions.create(\n",
+    "            model=model,\n",
+    "            messages=messages,\n",
+    "            tools=output_modes[mode][\"functions\"],\n",
+    "            extra_body={\"enable_thinking\": False},\n",
+    "        )\n",
+    "        \n",
+    "        # 检查是否有工具调用\n",
+    "        if response.choices[0].message.tool_calls:\n",
+    "            result = json.loads(response.choices[0].message.tool_calls[0].function.arguments)\n",
+    "            if not isinstance(result, dict):\n",
+    "                raise Exception(\"Invalid result format: expected dictionary\")\n",
+    "            \n",
+    "            if \"sentiment\" in result:\n",
+    "                if mode == 'function_reason':\n",
+    "                    return result[\"sentiment\"], result[\"reasoning\"]\n",
+    "                elif mode == 'function':\n",
+    "                    return result[\"sentiment\"]\n",
+    "            \n",
+    "        # 如果没有工具调用,尝试从内容中提取情感\n",
+    "        content = response.choices[0].message.content.lower()\n",
+    "        temp_sentiment = ''\n",
+    "        if 'positive' in content:\n",
+    "            temp_sentiment = 'positive'\n",
+    "        elif 'negative' in content:\n",
+    "            temp_sentiment = 'negative'\n",
+    "        elif 'neutral' in content:\n",
+    "            temp_sentiment = 'neutral'\n",
+    "        elif 'mixed' in content:\n",
+    "            temp_sentiment = 'mixed'\n",
+    "        \n",
+    "        if temp_sentiment:\n",
+    "            if mode == 'function_reason':\n",
+    "                return temp_sentiment, content\n",
+    "            elif mode == 'function':\n",
+    "                return temp_sentiment\n",
+    "        else:\n",
+    "            raise Exception(f\"无法从响应中提取情感, content: {content}\")\n",
+    "    \n",
+    "    response = client.chat.completions.create(\n",
+    "        model=model,\n",
+    "        messages=messages,\n",
+    "        extra_body={\"enable_thinking\": False},\n",
+    "    )\n",
+    "    \n",
+    "    if mode == 'json':\n",
+    "        response = client.chat.completions.create(\n",
+    "            model=model,\n",
+    "            messages=messages,\n",
+    "            response_format={\"type\": \"json_object\"},\n",
+    "            extra_body={\"enable_thinking\": False},\n",
+    "        )\n",
+    "        result = json.loads(response.choices[0].message.content)\n",
+    "        return result[\"sentiment\"]\n",
+    "    \n",
+    "    return response.choices[0].message.content.strip().lower()\n",
+    "\n",
+    "# 计算混淆矩阵\n",
+    "def calculate_metrics(tp, fp, tn, fn, total_samples):\n",
+    "    accuracy = (tp + tn) / total_samples\n",
+    "    precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n",
+    "    recall = tp / (tp + fn) if (tp + fn) > 0 else 0\n",
+    "    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
+    "    output = []\n",
+    "\n",
+    "    # 主要指标\n",
+    "    metrics = {\n",
+    "        'accuracy': '准确率',\n",
+    "        'precision': '精确率',\n",
+    "        'recall': '召回率',\n",
+    "        'f1': 'F1分数'\n",
+    "    }\n",
+    "    \n",
+    "    for key, name in metrics.items():\n",
+    "        value = locals()[key]\n",
+    "        output.append(f\"{name}: {value:.2%}\")\n",
+    "    \n",
+    "    # 混淆矩阵\n",
+    "    output.append(\"\\n混淆矩阵:\")\n",
+    "    output.append(f\"真正例 (TP): {tp:4d}\\t假正例 (FP): {fp:4d}\")\n",
+    "    output.append(f\"假负例 (FN): {fn:4d}\\t真负例 (TN): {tn:4d}\")\n",
+    "    \n",
+    "    return output\n",
+    "\n",
+    "def evaluate_sentiment(client, model, positive_texts, negative_texts, mode='text'):\n",
+    "    \"\"\"评估情感分析性能\n",
+    "    \n",
+    "    Args:\n",
+    "        client: OpenAI客户端\n",
+    "        model: 使用的模型名称\n",
+    "        positive_texts: 正面评价文本列表\n",
+    "        negative_texts: 负面评价文本列表\n",
+    "        mode: 输出模式,可选 'text', 'json', 'function'\n",
+    "    \"\"\"\n",
+    "    \n",
+    "    tp = fp = tn = fn = 0\n",
+    "    total_samples = len(positive_texts) + len(negative_texts)\n",
+    "    \n",
+    "    # 处理正面样本\n",
+    "    for text in tqdm(positive_texts, desc=f\"Processing positive samples ({mode} mode)\"):\n",
+    "        try:\n",
+    "            prediction = get_sentiment(client, model, text, mode)\n",
+    "            if prediction == 'positive':\n",
+    "                tp += 1\n",
+    "            else:\n",
+    "                fn += 1\n",
+    "        except BadRequestError as e:\n",
+    "            fn += 1\n",
+    "    \n",
+    "    # 处理负面样本\n",
+    "    for text in tqdm(negative_texts, desc=f\"Processing negative samples ({mode} mode)\"):\n",
+    "        try:\n",
+    "            prediction = get_sentiment(client, model, text, mode)\n",
+    "            if prediction == 'negative':\n",
+    "                tn += 1\n",
+    "            else:\n",
+    "                fp += 1\n",
+    "        except BadRequestError as e:\n",
+    "            fp += 1\n",
+    "    \n",
+    "    # 计算并返回指标\n",
+    "    return calculate_metrics(tp, fp, tn, fn, total_samples)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "735ac5fa",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始评估情感分析性能...\n",
+      "使用模型: qwen3-4b, 输出模式: text\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (text mode): 100%|██████████| 121/121 [00:56<00:00,  2.13it/s]\n",
+      "Processing negative samples (text mode): 100%|██████████| 121/121 [00:57<00:00,  2.11it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "评估结果: {'accuracy': 0.9049586776859504, 'precision': 0.9380530973451328, 'recall': 0.8688524590163934, 'f1': 0.902127659574468, 'confusion_matrix': {'tp': 106, 'fp': 7, 'tn': 113, 'fn': 16}}\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = \"qwen3-4b\"\n",
+    "mode = 'text'\n",
+    "print(f\"开始评估情感分析性能...\\n使用模型: {model}, 输出模式: {mode}\")\n",
+    "results = evaluate_sentiment(client, model, positive_texts, negative_texts, mode)\n",
+    "print(f\"评估结果: {results}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "9a3ca72e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始评估情感分析性能...\n",
+      "使用模型: qwen3-4b, 输出模式: json\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (json mode): 100%|██████████| 121/121 [03:32<00:00,  1.75s/it]\n",
+      "Processing negative samples (json mode): 100%|██████████| 121/121 [03:20<00:00,  1.66s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "评估结果: \n",
+      "['准确率: 91.74%', '精确率: 93.97%', '召回率: 89.34%', 'F1分数: 91.60%', '\\n混淆矩阵:', '真正例 (TP):  109\\t假正例 (FP):    7', '假负例 (FN):   13\\t真负例 (TN):  113']\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = \"qwen3-4b\"\n",
+    "mode = 'json'\n",
+    "print(f\"开始评估情感分析性能...\\n使用模型: {model}, 输出模式: {mode}\")\n",
+    "results = evaluate_sentiment(client, model, positive_texts, negative_texts, mode)\n",
+    "print(f\"评估结果: \\n{results}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "a074541e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始评估情感分析性能...\n",
+      "使用模型: qwen3-4b, 输出模式: function\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (function mode): 100%|██████████| 121/121 [01:43<00:00,  1.17it/s]\n",
+      "Processing negative samples (function mode): 100%|██████████| 121/121 [01:38<00:00,  1.23it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "评估结果: \n",
+      "准确率: 91.74%\n",
+      "精确率: 96.40%\n",
+      "召回率: 86.99%\n",
+      "F1分数: 91.45%\n",
+      "\n",
+      "混淆矩阵:\n",
+      "真正例 (TP):  107\t假正例 (FP):    4\n",
+      "假负例 (FN):   16\t真负例 (TN):  115\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = \"qwen3-4b\"\n",
+    "mode = 'function'\n",
+    "print(f\"开始评估情感分析性能...\\n使用模型: {model}, 输出模式: {mode}\")\n",
+    "results = evaluate_sentiment(client, model, positive_texts, negative_texts, mode)\n",
+    "print(\"\\n评估结果: \")\n",
+    "print(\"\\n\".join(results))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1c4d4ecc",
+   "metadata": {},
+   "source": [
+    "3. 比较以下大模型在这个任务上的 accuracy 差异\n",
+    "   1. qwen3-32b\n",
+    "   2. qwen3-30b-a3b\n",
+    "   3. qwen3-0.6b\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "660d3af7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始比较不同模型的性能...\n",
+      "\n",
+      "评估模型: qwen3-0.6b\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (function mode): 100%|██████████| 121/121 [00:59<00:00,  2.05it/s]\n",
+      "Processing negative samples (function mode): 100%|██████████| 121/121 [00:56<00:00,  2.15it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "模型准确率比较:\n",
+      "qwen3-0.6b: 79.34%\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# models = ['qwen3-32b', 'qwen3-30b-a3b', 'qwen3-0.6b']\n",
+    "models = ['qwen3-0.6b']\n",
+    "mode = 'function'\n",
+    "\n",
+    "print(\"开始比较不同模型的性能...\")\n",
+    "model_results = {}\n",
+    "\n",
+    "for model in models:\n",
+    "    print(f\"\\n评估模型: {model}\")\n",
+    "    try:\n",
+    "        results = evaluate_sentiment(client, model, positive_texts, negative_texts, mode)\n",
+    "        accuracy = results[0].split(': ')[1].rstrip('%')\n",
+    "        model_results[model] = float(accuracy)\n",
+    "    except Exception as e:\n",
+    "        print(f\"评估 {model} 时出错: {str(e)}\")\n",
+    "        model_results[model] = None\n",
+    "\n",
+    "print(\"\\n模型准确率比较:\")\n",
+    "for model, accuracy in model_results.items():\n",
+    "    if accuracy is not None:\n",
+    "        print(f\"{model}: {accuracy:.2f}%\")\n",
+    "    else:\n",
+    "        print(f\"{model}: 评估失败\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "id": "5c2257a4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始比较不同模型的性能...\n",
+      "\n",
+      "评估模型: qwen3-32b\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (function mode): 100%|██████████| 121/121 [02:17<00:00,  1.13s/it]\n",
+      "Processing negative samples (function mode): 100%|██████████| 121/121 [02:19<00:00,  1.15s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "评估模型: qwen3-30b-a3b\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Processing positive samples (function mode): 100%|██████████| 121/121 [01:41<00:00,  1.19it/s]\n",
+      "Processing negative samples (function mode): 100%|██████████| 121/121 [01:39<00:00,  1.21it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "模型准确率比较:\n",
+      "qwen3-32b: 92.98%\n",
+      "qwen3-30b-a3b: 93.80%\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "models = ['qwen3-32b', 'qwen3-30b-a3b']\n",
+    "mode = 'function'\n",
+    "\n",
+    "print(\"开始比较不同模型的性能...\")\n",
+    "model_results = {}\n",
+    "\n",
+    "for model in models:\n",
+    "    print(f\"\\n评估模型: {model}\")\n",
+    "    try:\n",
+    "        results = evaluate_sentiment(client, model, positive_texts, negative_texts, mode)\n",
+    "        accuracy = results[0].split(': ')[1].rstrip('%')\n",
+    "        model_results[model] = float(accuracy)\n",
+    "    except Exception as e:\n",
+    "        print(f\"评估 {model} 时出错: {str(e)}\")\n",
+    "        model_results[model] = None\n",
+    "\n",
+    "print(\"\\n模型准确率比较:\")\n",
+    "for model, accuracy in model_results.items():\n",
+    "    if accuracy is not None:\n",
+    "        print(f\"{model}: {accuracy:.2f}%\")\n",
+    "    else:\n",
+    "        print(f\"{model}: 评估失败\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cf01d5d7",
+   "metadata": {},
+   "source": [
+    "4. 按照以下步骤进行 bad cases 分析:\n",
+    "    1. 使用: qwen3-0.6b 模型, 使用tools进行预测\n",
+    "    2. 进行情感分析的时候, 让大模型能够同时给出情感分析结果和原因(大模型输出)\n",
+    "    3. 挑选大模型预测错误的 cases\n",
+    "    4. 人工查看错误 cases 中大模型分类错误的原因, 并进行归类. 得出每一类错误造成的错误数量.\n",
+    "    5. 根据错误分析结果, 优化prompt, 提升大模型对分类任务的准确率.\n",
+    "    6. 注意temperature参数对结果的影响.\n",
+    "    7. 产出物: csv文件(index,对应文件路径,positive/negative,arguments)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1bad4d36",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "from datetime import datetime\n",
+    "\n",
+    "def analyze_bad_cases(client, model, positive_texts, negative_texts):\n",
+    "    \"\"\"分析错误案例\n",
+    "    \n",
+    "    Args:\n",
+    "        client: OpenAI客户端\n",
+    "        model: 使用的模型名称\n",
+    "        positive_texts: 正面评价文本列表\n",
+    "        negative_texts: 负面评价文本列表\n",
+    "    \"\"\"\n",
+    "    bad_cases = []\n",
+    "    bad_cases_count=0\n",
+    "    \n",
+    "    # 分析正面样本\n",
+    "    for idx, text in enumerate(tqdm(positive_texts, desc=\"分析正面样本\")):\n",
+    "        try:\n",
+    "            sentiment, reasoning = get_sentiment(client, model, text, mode='function_reason')\n",
+    "            if sentiment != 'positive':\n",
+    "                bad_cases.append({\n",
+    "                    'text_id': f'positive_{idx}',\n",
+    "                    'text': text[:200],  # 只保存前200个字符\n",
+    "                    'expected': 'positive',\n",
+    "                    'predicted': sentiment,\n",
+    "                    'reasoning': reasoning\n",
+    "                })\n",
+    "        except BadRequestError as e:\n",
+    "            bad_cases.append({\n",
+    "                'text_id': f'positive_{idx}',\n",
+    "                'text': text[:200],\n",
+    "                'expected': 'positive',\n",
+    "                'predicted': 'error',\n",
+    "                'reasoning': str(e)\n",
+    "            })\n",
+    "            \n",
+    "    # 保存正面样本的错误案例\n",
+    "    df_positive = pd.DataFrame(bad_cases)\n",
+    "    timestamp = datetime.now().strftime('%Y%m%d')\n",
+    "    filename = f'bad_cases_{model}_{timestamp}.csv'\n",
+    "    df_positive.to_csv(filename, index=False, encoding='utf-8-sig')\n",
+    "    \n",
+    "    # 清空bad_cases列表,准备分析负面样本\n",
+    "    bad_cases_count = bad_cases_count + len(bad_cases)\n",
+    "    bad_cases.clear()\n",
+    "    \n",
+    "    # 分析负面样本\n",
+    "    for idx, text in enumerate(tqdm(negative_texts, desc=\"分析负面样本\")):\n",
+    "        try:\n",
+    "            sentiment, reasoning = get_sentiment(client, model, text, mode='function_reason')\n",
+    "            if sentiment != 'negative':\n",
+    "                bad_cases.append({\n",
+    "                    'text_id': f'negative_{idx}',\n",
+    "                    'text': text[:200],\n",
+    "                    'expected': 'negative',\n",
+    "                    'predicted': sentiment,\n",
+    "                    'reasoning': reasoning\n",
+    "                })\n",
+    "        except BadRequestError as e:\n",
+    "            bad_cases.append({\n",
+    "                'text_id': f'negative_{idx}',\n",
+    "                'text': text[:200],\n",
+    "                'expected': 'negative',\n",
+    "                'predicted': 'error',\n",
+    "                'reasoning': str(e)\n",
+    "            })\n",
+    "    \n",
+    "    # 保存结果到CSV\n",
+    "    df = pd.DataFrame(bad_cases)\n",
+    "    filename = f'bad_cases_{model}_{timestamp}.csv'\n",
+    "    df.to_csv(filename, mode='a', header=False, index=False, encoding='utf-8-sig')\n",
+    "    bad_cases_count = bad_cases_count + len(bad_cases)\n",
+    "    \n",
+    "    print(f\"\\n分析完成!共发现 {bad_cases_count} 个错误案例\")\n",
+    "    print(f\"结果已保存至: {filename}\")\n",
+    "    \n",
+    "    return bad_cases"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "324ec3d8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始进行情感分析...\n",
+      "使用模型: qwen3-0.6b, 输出模式: function_reason\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "分析正面样本: 100%|██████████| 36/36 [00:24<00:00,  1.49it/s]\n",
+      "分析负面样本: 100%|██████████| 36/36 [00:25<00:00,  1.42it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "分析完成!共发现 1 个错误案例\n",
+      "结果已保存至: bad_cases_qwen3-0.6b_20250722.csv\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = \"qwen3-0.6b\"\n",
+    "mode = 'function_reason'\n",
+    "print(f\"开始进行情感分析...\\n使用模型: {model}, 输出模式: {mode}\")\n",
+    "bad_cases = analyze_bad_cases(client, model, positive_texts, negative_texts)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ai-learning",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 229 - 0
陈敬安/theme3/theme3_ex2.ipynb

@@ -0,0 +1,229 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ddbc9058",
+   "metadata": {},
+   "source": [
+    "## 练习2. 信息收集多轮对话机器人\n",
+    "使用 OpenAI 的预测接口(非 chat API),自己维护 chat history 信息,完成多轮对话,完成一个信息收集助手。\n",
+    "\n",
+    "1. 可以收集用户的姓名、年龄和用户感兴趣的行业\n",
+    "2. 不要存储任何中间状态,只记录对话历史\n",
+    "3. 开始后,\n",
+    "    - AI:请输入你的姓名\n",
+    "    - User: xxx\n",
+    "    - AI: 请输入你的年龄\n",
+    "    - User: yyy\n",
+    "    - AI: 请输入你感兴趣的行业\n",
+    "    - User:zzz\n",
+    "    - AI: 收集完成,结果是xxxx\n",
+    "4. 用户中途可以随意输入其他问题,但是 AI 要提醒用户不回答其他问题,然后提醒用户还没收集完成\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "29c715a6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "已设置 OpenAI 客户端,base_url: http://103.154.31.78:20001/compatible-mode/v1/\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "from dotenv import load_dotenv\n",
+    "from openai import OpenAI\n",
+    "\n",
+    "def create_client():\n",
+    "    load_dotenv()\n",
+    "    client = OpenAI(\n",
+    "        base_url=os.getenv(\"BAILIAN_API_BASE_URL\"),\n",
+    "        api_key=os.getenv(\"BAILIAN_API_KEY\")\n",
+    "    )\n",
+    "    return client\n",
+    "\n",
+    "client = create_client()\n",
+    "print(f\"已设置 OpenAI 客户端,base_url: {client.base_url}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4da87a9b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "\n",
+    "questions = {\n",
+    "    'name': '请输入你的姓名',\n",
+    "    'age': '请输入你的年龄',\n",
+    "    'industry': '请输入你感兴趣的行业'\n",
+    "}\n",
+    "\n",
+    "def create_prompt(history, collected_info):\n",
+    "    system_prompt = f\"\"\"\n",
+    "    你是一个信息收集助手。你需要依次收集用户的姓名、年龄和感兴趣的行业。\n",
+    "    如果用户输入其他无关内容,请提醒用户需要先完成信息收集。\n",
+    "\n",
+    "    信息收集策略:\n",
+    "    1. 使用开放式问题引导用户提供更多信息\n",
+    "    2. 对模糊的信息进行澄清和确认\n",
+    "    3. 根据上下文,智能地推断和补充相关信息\n",
+    "    4. 保持对话的自然性和连贯性\n",
+    "    5. 务必保证信息的准确性和完整性\n",
+    "    6. 只有当信息符合需要收集的内容的规则时(例如姓名必须是符合姓名规则的,年龄必须是1~100内的数字等等),才将其记录到已收集信息中\n",
+    "    7. 如果用户输入年龄时给出的是出生年份,请帮用户计算实际年龄(今年是2025年)。\n",
+    "    8. 如果用户输入的年龄没有明确的数字,则此输入无效。\n",
+    "\n",
+    "    对话要求:\n",
+    "    1. 保持友好、专业的语调\n",
+    "    2. 每次回复要简洁明了\n",
+    "    3. 每个信息收集成功后,确认用户输入的内容,格式为:“好的,已经收集到您的(此处为收集的项目)是:”,并询问下一个问题。\n",
+    "    4. 当用户询问一些其他问题时,请礼貌地告诉用户,你是一个信息收集助手,不回答其它不相关问题。提醒用户还没收集完成。\n",
+    "    5. 所有问题回答并收集完毕后,请总结用户的所有信息,并告知用户“信息已收集完毕”。\n",
+    "    6. 如果没有对话历史,请先介绍自己,再进行询问。\n",
+    "    7. 如果用户回答与问题无关,请礼貌且友好地提醒用户需要回答相关问题。\n",
+    "    8. 请不要直接将问题内容直接输出,而是要根据用户的回答进行智能化的处理和回复。\n",
+    "    \"\"\"\n",
+    "\n",
+    "    prompt = system_prompt + \"\\n\\n对话历史:\\n\"\n",
+    "    for h in history:\n",
+    "        prompt += f\"用户:{h['user']}\\nAI:{h['assistant']}\\n\"\n",
+    "    \n",
+    "    if 'name' not in collected_info:\n",
+    "        prompt += f\"\\n下一个问题:{questions['name']}\"\n",
+    "    elif 'age' not in collected_info:\n",
+    "        prompt += f\"\\n下一个问题:{questions['age']}\"\n",
+    "    elif 'industry' not in collected_info:\n",
+    "        prompt += f\"\\n下一个问题:{questions['industry']}\"\n",
+    "    else:\n",
+    "        prompt += \"\\n所有信息已收集完成\"\n",
+    "    \n",
+    "    return prompt\n",
+    "\n",
+    "def chat_completion(prompt, model):\n",
+    "    response = client.chat.completions.create(\n",
+    "        model=model,\n",
+    "        messages=prompt,\n",
+    "        max_tokens=500,\n",
+    "        temperature=0.7,\n",
+    "        stop=[\"用户:\"],  # 防止模型生成用户的发言\n",
+    "        extra_body={\"enable_thinking\": False}\n",
+    "    )\n",
+    "    return response.choices[0].message.content.strip()\n",
+    "\n",
+    "def information_collection(model):\n",
+    "    history = []\n",
+    "    collected_info = {}\n",
+    "    \n",
+    "    # 开始对话\n",
+    "    prompt = create_prompt(history, collected_info)\n",
+    "    messages = [{\"role\": \"system\", \"content\": prompt}]\n",
+    "    response = chat_completion(messages, model)\n",
+    "    print(f\"AI: {response}\")\n",
+    "    time.sleep(1)  # 等待一秒钟,模拟用户思考时间\n",
+    "    \n",
+    "    while True:\n",
+    "        user_input = input(\"用户: \")\n",
+    "        if user_input.lower() in ['exit', 'quit']:\n",
+    "            print(\"对话结束。\")\n",
+    "            break\n",
+    "        if not user_input.strip():\n",
+    "            print(\"输入不能为空,请重新输入。\")\n",
+    "            continue\n",
+    "        print(f\"用户: {user_input}\")\n",
+    "        \n",
+    "        # 记录用户输入\n",
+    "        history.append({\"user\": user_input, \"assistant\": \"\"})\n",
+    "\n",
+    "        # 生成回复\n",
+    "        prompt = create_prompt(history, collected_info)\n",
+    "        messages = [{\"role\": \"system\", \"content\": prompt}, {\"role\": \"user\", \"content\": user_input}]\n",
+    "        response = chat_completion(messages, model)\n",
+    "        if response.startswith(\"好的,已经收集到您的\"):\n",
+    "            # 提取已收集的信息\n",
+    "            if '姓名' in response:\n",
+    "                collected_info['name'] = user_input\n",
+    "            elif '年龄' in response:\n",
+    "                collected_info['age'] = user_input\n",
+    "            elif '行业' in response:\n",
+    "                collected_info['industry'] = user_input\n",
+    "        print(f\"AI: {response}\")\n",
+    "        time.sleep(1)\n",
+    "        \n",
+    "        # 记录AI回复\n",
+    "        history[-1][\"assistant\"] = response\n",
+    "        \n",
+    "        # 如果所有信息都收集完毕,结束对话\n",
+    "        if '收集完毕' in response:\n",
+    "            break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "id": "808bd704",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "开始进行信息收集...\t使用模型: qwen3-30b-a3b\n",
+      "AI: 您好,我是您的信息收集助手。为了更好地为您提供服务,请您先告诉我您的姓名。\n",
+      "用户: 我是周深\n",
+      "AI: 好的,已经收集到您的姓名是:周深。接下来,请问您的年龄是多少?\n",
+      "用户: 我有点不记得了\n",
+      "AI: 好的,如果您不方便提供年龄,我们可以先进行下一步。请问您感兴趣的行业是什么?\n",
+      "用户: 我感兴趣的行业有很多,我2014年,参加节目《中国好声音第三季》正式出道。 [23]2015年,发行首支个人音乐单曲《玫瑰与小鹿》。 [1]2016年,为动画电影《大鱼海棠》演唱印象曲《大鱼》, [2]并凭该曲斩获多个奖项。 [34-35]2017年,发行首张个人音乐专辑《深的深》。 [4]2019年,在节目《中国梦之声·我们的歌》中夺冠, [112]并举行“C-929星球”全国巡回演唱会。 [13]2020年,作为首发歌手加盟音乐节目《歌手·当打之年》,演唱的歌曲《达拉崩吧》以破亿播放量斩获网易云音乐“云村”播放量第一; [380]同年起,连续两年获腾讯音乐娱乐盛典年度最佳内地男歌手奖。 [346] 2021年,首次登上央视春晚舞台并演唱歌曲《灯火里的中国》; [374]同年,为联合国COP15演唱主题曲《和光同春》。 [312]2022年起,加盟户外竞技真人秀《奔跑吧》担任常驻嘉宾; [431]同年,与最高人民检察院合作推出的单曲《我的答案》获中宣部精神文明建设五个一工程奖。 [501]2023年,携个人单曲《花开忘忧》再度登上央视春晚舞台; [526]同年,获得第30届东方风云榜最佳男歌手奖。 [5]2024年,发行第2张个人音乐专辑《反深代词》, [728]专辑销售额突破6000万; [803]周深9.29Hz巡回演唱会同步启动; [806]同年,应邀前往纽约联合国总部演讲并演唱交响诗《和平颂》。 [676]2025年,开启4大洲5国8城世界巡回演唱会, [812]并参加央视春晚; [822]5月,参与配音的动画电影《时间之子》上映。\n",
+      "AI: 好的,已经收集到您的兴趣行业是:音乐。接下来,请问您的年龄是多少?\n",
+      "用户: 我1992年9月29日出生。\n",
+      "AI: 好的,已经收集到您的出生日期是:1992年9月29日。根据今年是2025年,您的年龄是32岁。信息确认无误吗?\n",
+      "用户: 请问今天天气怎么样?\n",
+      "AI: 抱歉,我是一个信息收集助手,目前需要先完成信息收集。请您回答您感兴趣的行业是什么?\n",
+      "用户: 我不是已经回答过了吗\n",
+      "AI: 是的,您之前提到过您感兴趣的行业是音乐。好的,已经收集到您的兴趣行业是:音乐。信息确认无误吗?\n",
+      "用户: 是的\n",
+      "AI: 好的,已经收集到您的兴趣行业是:音乐。信息确认无误吗?\n",
+      "用户: 是的\n",
+      "AI: 信息已收集完毕。\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = \"qwen3-30b-a3b\"\n",
+    "print(f\"开始进行信息收集...\\t使用模型: {model}\")\n",
+    "information_collection(model)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ai-learning",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 497 - 0
陈敬安/theme4/theme4_ex3.ipynb

@@ -0,0 +1,497 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "4a50f764",
+   "metadata": {},
+   "source": [
+    "### 练习3. 转账Agent\n",
+    "设计一个 Agent, 能够协助用户完成转账操作\n",
+    "\n",
+    "#### 目标\n",
+    "使用 qwen3-32b 模型\n",
+    "1. 向用户收集\n",
+    "    1. 对方姓名(写死只有“张三”存在并返回手机号,其他都不存在)\n",
+    "    2. 对方卡号\n",
+    "    3. 转账金额(只有500)\n",
+    "2. 收集完成后,调用转账接口向对方转账并提示用户转账结果\n",
+    "3. 如果转账失败用户可以重新修改信息并再次发起转账\n",
+    "\n",
+    "#### 要求\n",
+    "1. 给 Agent 提供必要的工具,Agent 来决定调用哪个工具。\n",
+    "2. 但是 Agent 必须能够正确地使用用户给的参数来调用对应的方法\n",
+    "\n",
+    "#### 场景\n",
+    "1. Agent 在收到满足上述 tools 要求的信息时,校验通过并完成转账\n",
+    "2. Agent 在收到不满足上述 tools 要求的信息时,向用户澄清并让用户修改信息。用户修改正确后,可以完成转账。\n",
+    "    1. 余额不足,和用户沟通,用户只能转少于500块\n",
+    "    2. 其他:账号手机号不存在,用户可以手动输入手机号\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d6203ac3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from dotenv import load_dotenv\n",
+    "\n",
+    "load_dotenv()\n",
+    "base_url = os.getenv('BAILIAN_API_BASE_URL')\n",
+    "api_key = os.getenv('BAILIAN_API_KEY')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "41fd0038",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from agno.agent import Agent\n",
+    "from agno.tools import tool\n",
+    "from agno.memory.v2.db.sqlite import SqliteMemoryDb\n",
+    "from agno.memory.v2.memory import Memory\n",
+    "from agno.storage.sqlite import SqliteStorage\n",
+    "from agno.models.openai.like import OpenAILike\n",
+    "from typing import Optional, Tuple\n",
+    "from dataclasses import dataclass\n",
+    "from enum import Enum\n",
+    "import time\n",
+    "\n",
+    "SYSTEM_PROMPT = \"\"\"\n",
+    "你是一个转账助手,帮助用户完成转账操作。\n",
+    "\n",
+    "规则:\n",
+    "1. 系统中只存在\"张三\"这个用户\n",
+    "2. 最大转账金额为不能超过用户的余额\n",
+    "3. 需要验证收款人姓名、账号和金额\n",
+    "4. 如果转账失败,根据具体错误原因协助用户修改信息:\n",
+    "   - 如果用户不存在,提醒用户检查收款人姓名\n",
+    "   - 如果手机号不匹配,提醒用户确认手机号\n",
+    "   - 如果余额不足,建议用户调整转账金额\n",
+    "5. 如果用户有其他疑问,提供相关帮助\n",
+    "\n",
+    "请按照以下步骤操作:\n",
+    "1. 开场先介绍自己是转账助手,询问用户接下来的操作,如果是转账,则按照下面的步骤;如果是其他问题,则提供相关帮助。\n",
+    "2. 收集用户信息:收款人姓名、账号、转账金额\n",
+    "3. 验证信息的正确性,如果信息不完整或错误,引导用户修改\n",
+    "4. 执行转账操作,如果转账成功,告知用户转账成功\n",
+    "5. 如果失败,引导用户修改信息重试\n",
+    "\"\"\"\n",
+    "\n",
+    "# Mock database for user info\n",
+    "USER_DB = {\n",
+    "    \"张三\": \"13800138000\"\n",
+    "}\n",
+    "\n",
+    "# UserId for the memories\n",
+    "user_id = \"test\"\n",
+    "# Database file for memory and storage\n",
+    "db_file = \"tmp/agent.db\"\n",
+    "\n",
+    "model = OpenAILike(\n",
+    "            id=\"qwen3-32b\",\n",
+    "            api_key=api_key,\n",
+    "            base_url=base_url,\n",
+    "            request_params={\"extra_body\": {\"enable_thinking\": False}},\n",
+    "        )\n",
+    "\n",
+    "# Initialize memory.v2\n",
+    "memory = Memory(\n",
+    "    # Use any model for creating memories\n",
+    "    model=model,\n",
+    "    db=SqliteMemoryDb(table_name=\"user_memories\", db_file=db_file),\n",
+    ")\n",
+    "# Initialize storage\n",
+    "storage = SqliteStorage(table_name=\"agent_sessions\", db_file=db_file)\n",
+    "\n",
+    "# Tools for the agent\n",
+    "def get_contact(user_name: str) -> Optional[str]:\n",
+    "    return USER_DB.get(user_name, None)\n",
+    "\n",
+    "def get_balance() -> float:\n",
+    "    return 500.0\n",
+    "\n",
+    "def reply_to_user(message: str) -> None:\n",
+    "    print(message)\n",
+    "\n",
+    "def validate_amount(amount: float) -> bool:\n",
+    "    return amount <= get_balance()\n",
+    "\n",
+    "class TransferError(Enum):\n",
+    "    USER_NOT_FOUND = \"用户不存在\"\n",
+    "    INVALID_PHONE = \"手机号不匹配\"\n",
+    "    INSUFFICIENT_BALANCE = \"余额不足\"\n",
+    "    NONE = \"无错误\"\n",
+    "\n",
+    "@dataclass\n",
+    "class TransferResult:\n",
+    "    success: bool\n",
+    "    error: TransferError\n",
+    "    message: str\n",
+    "\n",
+    "def validate_transfer(user_name: str, phone: str, amount: float) -> TransferResult:\n",
+    "    \"\"\"验证转账信息并返回详细结果\"\"\"\n",
+    "    if user_name not in USER_DB:\n",
+    "        return TransferResult(False, TransferError.USER_NOT_FOUND, f\"收款人 {user_name} 不存在\")\n",
+    "    \n",
+    "    if not validate_amount(amount):\n",
+    "        return TransferResult(False, TransferError.INSUFFICIENT_BALANCE, \n",
+    "                            f\"转账金额 {amount} 超过账户余额 {get_balance()}\")\n",
+    "    \n",
+    "    if phone != get_contact(user_name):\n",
+    "        return TransferResult(False, TransferError.INVALID_PHONE, \n",
+    "                            f\"手机号 {phone} 与用户信息不匹配\")\n",
+    "    \n",
+    "    return TransferResult(True, TransferError.NONE, \"验证通过\")\n",
+    "\n",
+    "def transfer(user_name: str, phone: str, amount: float) -> Tuple[bool, str]:\n",
+    "    \"\"\"执行转账操作,返回转账结果和详细信息\"\"\"\n",
+    "    result = validate_transfer(user_name, phone, amount)\n",
+    "    if result.success:\n",
+    "        return True, \"转账成功\"\n",
+    "    return False, result.message\n",
+    "\n",
+    "agent = Agent(\n",
+    "    model=model,\n",
+    "    tools=[validate_amount, transfer, reply_to_user, get_contact, get_balance],\n",
+    "    instructions=SYSTEM_PROMPT,\n",
+    "    add_datetime_to_instructions=True,\n",
+    "    show_tool_calls=True,\n",
+    "    markdown=True,\n",
+    "    memory=memory,\n",
+    "    enable_user_memories=True,\n",
+    "    storage=storage,\n",
+    "    add_history_to_messages=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "fb2117b1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "转账助手已启动。请开始对话。\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a95695bf39f84f5b8c9cdb1251f7995f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fdfb2b68c45b44619135d51ba709421a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e508c5b43a846748cb26de3897265e3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fbff122df31749959d4b2af37d9e579e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5237ded427d64f20a9157511cafe0c6a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "daaa1ff32aeb4f98b0c738dc555cf327",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "608035432a5d44039c32628a1fc8d460",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a548a49f4dc0491abb69679c71b4d46a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e745199fedab4c898e105ee1b3311a25",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "对话结束。\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"转账助手已启动。请开始对话。\")\n",
+    "agent.print_response(\"你好\", stream=True, user_id=user_id)\n",
+    "while True:\n",
+    "    user_input = input(\"用户: \")\n",
+    "    if user_input.lower() in ['exit', 'quit']:\n",
+    "        print(\"对话结束。\")\n",
+    "        break\n",
+    "    if not user_input.strip():\n",
+    "        print(\"输入不能为空,请重新输入。\")\n",
+    "        continue\n",
+    "    # print(f\"用户: {user_input}\")\n",
+    "\n",
+    "    agent.print_response(user_input, stream=True, user_id=user_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "95207ed5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Memories about 上个会话:\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">[</span>\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">UserMemory</span><span style=\"font-weight: bold\">(</span>\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #808000; text-decoration-color: #808000\">memory</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'Her phone number is 13800138000.'</span>,\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #808000; text-decoration-color: #808000\">topics</span>=<span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'contact'</span><span style=\"font-weight: bold\">]</span>,\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #808000; text-decoration-color: #808000\">input</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'手机号:13800138000'</span>,\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #808000; text-decoration-color: #808000\">last_updated</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">datetime</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.datetime</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2025</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">22</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">15</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">55</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">33</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">469818</span><span style=\"font-weight: bold\">)</span>,\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   │   </span><span style=\"color: #808000; text-decoration-color: #808000\">memory_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'e9590a07-edbe-410b-b336-c454654dbd8f'</span>\n",
+       "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│   </span><span style=\"font-weight: bold\">)</span>\n",
+       "<span style=\"font-weight: bold\">]</span>\n",
+       "</pre>\n"
+      ],
+      "text/plain": [
+       "\u001b[1m[\u001b[0m\n",
+       "\u001b[2;32m│   \u001b[0m\u001b[1;35mUserMemory\u001b[0m\u001b[1m(\u001b[0m\n",
+       "\u001b[2;32m│   │   \u001b[0m\u001b[33mmemory\u001b[0m=\u001b[32m'Her phone number is 13800138000.'\u001b[0m,\n",
+       "\u001b[2;32m│   │   \u001b[0m\u001b[33mtopics\u001b[0m=\u001b[1m[\u001b[0m\u001b[32m'contact'\u001b[0m\u001b[1m]\u001b[0m,\n",
+       "\u001b[2;32m│   │   \u001b[0m\u001b[33minput\u001b[0m=\u001b[32m'手机号:13800138000'\u001b[0m,\n",
+       "\u001b[2;32m│   │   \u001b[0m\u001b[33mlast_updated\u001b[0m=\u001b[1;35mdatetime\u001b[0m\u001b[1;35m.datetime\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m2025\u001b[0m, \u001b[1;36m7\u001b[0m, \u001b[1;36m22\u001b[0m, \u001b[1;36m15\u001b[0m, \u001b[1;36m55\u001b[0m, \u001b[1;36m33\u001b[0m, \u001b[1;36m469818\u001b[0m\u001b[1m)\u001b[0m,\n",
+       "\u001b[2;32m│   │   \u001b[0m\u001b[33mmemory_id\u001b[0m=\u001b[32m'e9590a07-edbe-410b-b336-c454654dbd8f'\u001b[0m\n",
+       "\u001b[2;32m│   \u001b[0m\u001b[1m)\u001b[0m\n",
+       "\u001b[1m]\u001b[0m\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from rich.pretty import pprint\n",
+    "\n",
+    "print(\"Memories about 上个会话:\")\n",
+    "pprint(memory.get_user_memories(user_id=user_id))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ai-learning",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

BIN
陈敬安/theme4/tmp/agent.db


Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff