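Agents in Practice: Building a Research Assistant Agent from Scratch
We have covered plenty of theory; now it is time to build a real agent.
This article walks through building a "research assistant agent" from scratch. It searches for information, analyzes data, and writes reports: a productivity tool that does real work for you.
We cover the whole path from requirements to deployment.
Key points
- Requirements analysis and design
- Technology stack and architecture
- Tool integration: search, database, file handling
- Memory system implementation
- Multi-agent collaboration design
- Complete code implementation
- Deployment and monitoring
1. Requirements analysis
1.1 Functional requirements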
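Research assistant agent: requirements

Core features
├── Information gathering: web search, database queries, document reading
├── Data analysis: statistics, trend analysis, comparative analysis
├── Content generation: reports, summaries, charts
└── Continuous learning: remember user preferences, accumulate research experience

Interaction
├── Natural-language input: the user describes the research topic
├── Multi-turn dialogue: follow-up questions and refinement
├── Export: PDF, Word, and Markdown output
└── Progress feedback: live display of research progress

Non-functional requirements
├── Response time: under 30 seconds for a simple research task
├── Accuracy: traceable sources, verifiable citations
├── Extensibility: easy to add new data sources and analysis tools
└── Security: protection of sensitive data, operation logging

1.2 Typical usage scenarios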
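flowchart LR
    A[User: research a company] --> B[Search public information]
    B --> C[Query financial data]
    C --> D[Analyze the competitive landscape]
    D --> E[Generate a research report]
    F[User: analyze sales trends] --> G[Query the sales database]
    G --> H[Compute statistics]
    H --> I[Generate charts]
    I --> J[Write an analysis report]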
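Both scenarios are linear pipelines in which each step enriches a shared state. A minimal sketch of that idea in plain Python follows; the step functions and state keys are hypothetical placeholders, not the agents built later in this article.

```python
from typing import Callable, Dict, List

State = Dict[str, object]
Step = Callable[[State], State]


def run_pipeline(state: State, steps: List[Step]) -> State:
    """Run each step in order, passing the state along."""
    for step in steps:
        state = step(state)
    return state


# Hypothetical steps: real ones would call search APIs, databases, and an LLM.
def search_public_info(state: State) -> State:
    return {**state, "public_info": f"notes about {state['topic']}"}


def analyze(state: State) -> State:
    return {**state, "analysis": f"analysis of {state['public_info']}"}


def write_report(state: State) -> State:
    return {**state, "report": f"Report: {state['analysis']}"}


final = run_pipeline({"topic": "ExampleCorp"}, [search_public_info, analyze, write_report])
print(final["report"])
```

Each step returns a new dictionary instead of mutating its input, which makes individual steps easy to test in isolation.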
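2. Technology stack
2.1 Architecture
flowchart TB
    subgraph UserLayer["User layer"]
        A[Web UI / CLI]
    end
    subgraph AgentLayer["Agent layer"]
        B[Coordinator agent]
        C[Researcher agent]
        D[Analyst agent]
        E[Editor agent]
    end
    subgraph CapabilityLayer["Capability layer"]
        F[Tool manager]
        G[Memory system]
        H[Vector database]
    end
    subgraph DataLayer["Data layer"]
        I[Search API]
        J[Database]
        K[File system]
    end
    A --> B
    B --> C
    B --> D
    B --> E
    C --> F
    D --> F
    E --> F
    F --> I
    F --> J
    F --> K
    B --> G
    G --> H
2.2 Technology choices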
Technology choices

Core framework
├── LangChain: agent orchestration and tool management
├── LangGraph: workflow state management
└── OpenAI GPT-4o: primary LLM

Vector database
├── ChromaDB: local development and testing
└── Milvus or Pinecone: switchable for production deployment

Data storage
├── SQLite: lightweight data cache
└── Redis: session state management

API integrations
├── Serper API: web search
├── Tavily: deep search
└── Custom APIs: internal data sources

Deployment and operations
├── FastAPI: REST API service
├── Docker: containerized deployment
└── Prometheus + Grafana: monitoring and alerting

3. Project structure
research-agent/
├── src/
│   ├── agents/
│   │   ├── __init__.py
│   │   ├── base.py             # Agent base class
│   │   ├── researcher.py       # Researcher agent
│   │   ├── analyst.py          # Analyst agent
│   │   └── editor.py           # Editor agent
│   │
│   ├── tools/
│   │   ├── __init__.py
│   │   ├── search.py           # Search tools
│   │   ├── database.py         # Database tools
│   │   ├── file_processor.py   # File handling tools
│   │   └── calculator.py       # Calculation tools
│   │
│   ├── memory/
│   │   ├── __init__.py
│   │   ├── short_term.py       # Short-term memory
│   │   ├── long_term.py        # Long-term memory
│   │   └── episodic.py         # Episodic memory
│   │
│   ├── workflows/
│   │   ├── __init__.py
│   │   └── research_flow.py    # Research workflow
│   │
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── logger.py           # Logging
│   │   └── config.py           # Configuration management
│   │
│   └── api/
│       ├── __init__.py
│       ├── main.py             # FastAPI entry point
│       └── routes.py           # API routes
│
├── tests/
│   └── test_agents.py
│
├── config/
│   ├── settings.yaml           # Configuration file
│   └── prompts/                # Prompt templates
│
├── data/
│   └── chroma/                 # Vector database storage
│
├── requirements.txt
├── Dockerfile
└── README.md
4. Tool integration
4.1 Search tools
import os
from typing import Dict, List

import requests
from langchain_core.tools import StructuredTool


class SearchTools:
    """Search tool set."""

    def __init__(self):
        self.serper_api_key = os.getenv("SERPER_API_KEY")
        self.tavily_api_key = os.getenv("TAVILY_API_KEY")
        # @tool on an instance method would expose `self` as a tool argument,
        # so the bound methods are wrapped as tools here instead.
        for name in ("web_search", "deep_search", "fetch_webpage"):
            setattr(self, name, StructuredTool.from_function(getattr(self, f"_{name}"), name=name))

    def _web_search(self, query: str, num_results: int = 5) -> List[Dict]:
        """Search the web.

        Args:
            query: Search keywords.
            num_results: Number of results to return.

        Returns:
            A list of results with title, link, and snippet.
        """
        url = "https://google.serper.dev/search"
        headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
        payload = {"q": query, "num": num_results}

        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()
        results = response.json().get("organic", [])

        return [
            {
                "title": r.get("title", ""),
                "link": r.get("link", ""),
                "snippet": r.get("snippet", ""),
                "position": r.get("position", 0),
            }
            for r in results[:num_results]
        ]

    def _deep_search(self, query: str, search_depth: str = "basic") -> Dict:
        """Run a deeper search that returns broader information.

        Args:
            query: Search keywords.
            search_depth: Search depth (basic/advanced).

        Returns:
            A dictionary with the search results and an answer.
        """
        from tavily import TavilyClient

        client = TavilyClient(api_key=self.tavily_api_key)
        result = client.search(query=query, search_depth=search_depth, max_results=10)

        return {
            "answer": result.get("answer", ""),
            "results": result.get("results", []),
            "follow_up_questions": result.get("follow_up_questions", []),
        }

    def _fetch_webpage(self, url: str) -> str:
        """Fetch the text content of a web page.

        Args:
            url: Page URL.

        Returns:
            The page text.
        """
        from bs4 import BeautifulSoup

        response = requests.get(url, timeout=30)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")

        # Remove scripts and styles
        for element in soup(["script", "style"]):
            element.decompose()

        text = soup.get_text(separator="\n", strip=True)

        # Drop blank lines and cap the length at 200 lines
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return "\n".join(lines[:200])


# Create the tool instance
search_tools = SearchTools()
4.2 Database tools
import json
import os
import sqlite3
from typing import Dict, List, Optional, Union

from langchain_core.tools import StructuredTool


class DatabaseTools:
    """Database tool set."""

    def __init__(self, db_path: str = "./data/research.db"):
        self.db_path = db_path
        # sqlite3 does not create missing parent directories
        os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True)
        self._init_database()
        # @tool on an instance method would expose `self` as a tool argument,
        # so the bound methods are wrapped as tools here instead.
        for name in ("execute_query", "save_research", "get_research_history",
                     "save_preference", "get_preferences"):
            setattr(self, name, StructuredTool.from_function(getattr(self, f"_{name}"), name=name))

    def _init_database(self):
        """Create the tables if they do not exist."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Research records
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS research_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                topic TEXT NOT NULL,
                query TEXT,
                result TEXT,
                sources TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # User preferences
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS user_preferences (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                preference_key TEXT,
                preference_value TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(user_id, preference_key)
            )
        """)

        conn.commit()
        conn.close()

    def _execute_query(self, sql: str) -> Union[List[Dict], Dict]:
        """Run a SQL query.

        Args:
            sql: A SELECT statement.

        Returns:
            A list of result rows, or {"error": ...} on failure.
        """
        if not sql.strip().upper().startswith("SELECT"):
            return {"error": "Only SELECT queries are supported"}

        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        try:
            cursor = conn.execute(sql)
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            return {"error": str(e)}
        finally:
            conn.close()

    def _save_research(self, topic: str, query: str, result: str, sources: List[str]) -> int:
        """Save a research record.

        Args:
            topic: Research topic.
            query: Query text.
            result: Research result.
            sources: List of information sources.

        Returns:
            The record ID.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            INSERT INTO research_records (topic, query, result, sources)
            VALUES (?, ?, ?, ?)
        """, (topic, query, result, json.dumps(sources)))

        record_id = cursor.lastrowid
        conn.commit()
        conn.close()

        return record_id

    def _get_research_history(self, topic: Optional[str] = None, limit: int = 10) -> List[Dict]:
        """Return past research records.

        Args:
            topic: Optional topic filter.
            limit: Number of records to return.

        Returns:
            A list of records, newest first.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        if topic:
            cursor = conn.execute("""
                SELECT * FROM research_records
                WHERE topic LIKE ?
                ORDER BY created_at DESC
                LIMIT ?
            """, (f"%{topic}%", limit))
        else:
            cursor = conn.execute("""
                SELECT * FROM research_records
                ORDER BY created_at DESC
                LIMIT ?
            """, (limit,))

        results = [dict(row) for row in cursor.fetchall()]
        conn.close()

        return results

    def _save_preference(self, user_id: str, key: str, value: str) -> bool:
        """Save a user preference."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            INSERT OR REPLACE INTO user_preferences (user_id, preference_key, preference_value)
            VALUES (?, ?, ?)
        """, (user_id, key, value))

        conn.commit()
        conn.close()

        return True

    def _get_preferences(self, user_id: str) -> Dict[str, str]:
        """Return all preferences of a user."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        cursor = conn.execute("""
            SELECT preference_key, preference_value
            FROM user_preferences
            WHERE user_id = ?
        """, (user_id,))

        preferences = {row["preference_key"]: row["preference_value"] for row in cursor.fetchall()}
        conn.close()

        return preferences
4.3 File handling tools
import fnmatch
import os
from datetime import datetime
from typing import Dict, List, Optional

from langchain_core.tools import StructuredTool


class FileTools:
    """File handling tool set."""

    def __init__(self, output_dir: str = "./output"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # @tool on an instance method would expose `self` as a tool argument,
        # so the bound methods are wrapped as tools here instead.
        for name in ("save_report", "read_file", "list_files", "generate_markdown_report"):
            setattr(self, name, StructuredTool.from_function(getattr(self, f"_{name}"), name=name))

    def _save_report(self, filename: str, content: str, format: str = "markdown") -> str:
        """Save a research report.

        Args:
            filename: File name without extension.
            content: Report content.
            format: Output format (markdown/json/txt).

        Returns:
            Path of the saved file.
        """
        ext_map = {"markdown": ".md", "json": ".json", "txt": ".txt"}
        ext = ext_map.get(format, ".txt")
        # basename() keeps the report inside output_dir even if the name contains a path
        filepath = os.path.join(self.output_dir, f"{os.path.basename(filename)}{ext}")

        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)

        return filepath

    def _read_file(self, filepath: str) -> str:
        """Read a file.

        Args:
            filepath: File path.

        Returns:
            The file content.
        """
        if not os.path.exists(filepath):
            return f"Error: file not found: {filepath}"

        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()

    def _list_files(self, directory: Optional[str] = None, pattern: Optional[str] = None) -> List[str]:
        """List the files in a directory.

        Args:
            directory: Directory path; defaults to the output directory.
            pattern: File name pattern (wildcards supported).

        Returns:
            File paths, most recently modified first.
        """
        dir_path = directory or self.output_dir

        if not os.path.exists(dir_path):
            return []

        files = [
            os.path.join(dir_path, f)
            for f in os.listdir(dir_path)
            if not pattern or fnmatch.fnmatch(f, pattern)
        ]

        return sorted(files, key=os.path.getmtime, reverse=True)

    def _generate_markdown_report(self, title: str, sections: Dict[str, str],
                                  metadata: Optional[Dict] = None) -> str:
        """Generate a report in Markdown.

        Args:
            title: Report title.
            sections: Section contents as {section name: content}.
            metadata: Metadata such as author and date.

        Returns:
            The report as Markdown text.
        """
        lines = []

        # Title
        lines.append(f"# {title}\n")

        # Metadata
        if metadata:
            lines.append("---")
            for key, value in metadata.items():
                lines.append(f"{key}: {value}")
            lines.append("---\n")

        # Table of contents
        lines.append("## Contents\n")
        for i, section_name in enumerate(sections.keys(), 1):
            anchor = section_name.lower().replace(" ", "-")
            lines.append(f"{i}. [{section_name}](#{anchor})")
        lines.append("\n")

        # Body
        for section_name, content in sections.items():
            lines.append(f"## {section_name}\n")
            lines.append(content)
            lines.append("\n")

        # Timestamp
        lines.append(f"\n---\n*Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*")

        return "\n".join(lines)
4.4 Calculation tools
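import statistics
from typing import Dict, List, Union

from langchain_core.tools import StructuredTool


class CalculatorTools:
    """Calculation tool set."""

    def __init__(self):
        # @tool on an instance method would expose `self` as a tool argument,
        # so the bound methods are wrapped as tools here instead.
        for name in ("calculate", "analyze_statistics", "calculate_growth_rate", "compare_values"):
            setattr(self, name, StructuredTool.from_function(getattr(self, f"_{name}"), name=name))

    def _calculate(self, expression: str) -> Union[float, str]:
        """Evaluate a math expression.

        Args:
            expression: A math expression such as "2 + 3 * 4".

        Returns:
            The result, or an error message.
        """
        try:
            # The whitelist blocks names and function calls. It does not stop
            # expensive inputs such as huge exponents written with "**".
            allowed_chars = set("0123456789+-*/.() ")
            if not all(c in allowed_chars for c in expression):
                return "Error: expression contains illegal characters"

            return eval(expression, {"__builtins__": {}}, {})
        except Exception as e:
            return f"Calculation error: {e}"

    def _analyze_statistics(self, numbers: List[float]) -> Dict:
        """Compute summary statistics.

        Args:
            numbers: List of values.

        Returns:
            A dictionary of statistics.
        """
        if not numbers:
            return {"error": "No data"}

        return {
            "count": len(numbers),
            "sum": sum(numbers),
            "mean": statistics.mean(numbers),
            "median": statistics.median(numbers),
            "min": min(numbers),
            "max": max(numbers),
            "std_dev": statistics.stdev(numbers) if len(numbers) > 1 else 0,
            "variance": statistics.variance(numbers) if len(numbers) > 1 else 0,
        }

    def _calculate_growth_rate(self, start_value: float, end_value: float, periods: int = 1) -> Dict:
        """Compute growth rates.

        Args:
            start_value: Starting value.
            end_value: Ending value.
            periods: Number of periods.

        Returns:
            Total growth, average growth per period, and CAGR.
        """
        if start_value == 0 or periods <= 0:
            return {"error": "start_value must be non-zero and periods must be positive"}

        total_growth = (end_value - start_value) / start_value * 100
        avg_growth = total_growth / periods
        cagr = ((end_value / start_value) ** (1 / periods) - 1) * 100

        return {
            "total_growth_rate": f"{total_growth:.2f}%",
            "average_growth_rate": f"{avg_growth:.2f}%",
            "cagr": f"{cagr:.2f}%",
        }

    def _compare_values(self, values: Dict[str, float]) -> Dict:
        """Compare several values.

        Args:
            values: A {name: value} dictionary.

        Returns:
            Ranking, maximum, minimum, and range.
        """
        if not values:
            return {"error": "No data"}

        sorted_values = sorted(values.items(), key=lambda x: x[1], reverse=True)
        max_name, max_val = sorted_values[0]
        min_name, min_val = sorted_values[-1]

        return {
            "ranking": [{"name": k, "value": v} for k, v in sorted_values],
            "max": {"name": max_name, "value": max_val},
            "min": {"name": min_name, "value": min_val},
            "range": max_val - min_val,
        }
5. Memory system
5.1 Short-term memory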
from collections import deque
from typing import Dict, List


class ShortTermMemory:
    """Short-term memory: manages the conversation history."""

    def __init__(self, max_messages: int = 50, max_tokens: int = 8000):
        self.max_messages = max_messages
        self.max_tokens = max_tokens
        self.messages: deque = deque(maxlen=max_messages)
        self.summary = ""

    def add_message(self, role: str, content: str):
        """Add a message."""
        self.messages.append({
            "role": role,
            "content": content,
            "token_count": self._estimate_tokens(content),
        })

        # Compress when the token budget is exceeded
        if self._get_total_tokens() > self.max_tokens:
            self._compress()

    def get_context(self, include_summary: bool = True) -> List[Dict]:
        """Return the context, optionally prefixed with the history summary."""
        context = []

        if include_summary and self.summary:
            context.append({
                "role": "system",
                "content": f"[History summary] {self.summary}",
            })

        context.extend(list(self.messages))
        return context

    def clear(self):
        """Clear the memory."""
        self.messages.clear()
        self.summary = ""

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count."""
        # Crude estimate of about 2 characters per token. Chinese runs closer to
        # 1.5 characters per token and English closer to 4, so this is only a
        # middle-ground guess for mixed text.
        return len(text) // 2

    def _get_total_tokens(self) -> int:
        """Return the total estimated token count."""
        return sum(m["token_count"] for m in self.messages)

    def _compress(self):
        """Compress older messages into the summary."""
        # Keep the 10 most recent messages and summarize the rest
        if len(self.messages) <= 10:
            return

        to_compress = list(self.messages)[:-10]

        # An LLM could generate the summary here. This simplified version only
        # keeps the first 100 characters of the last 5 compressed messages.
        compressed_summary = " | ".join(
            f"{m['role']}: {m['content'][:100]}..." for m in to_compress[-5:]
        )

        self.summary = f"{self.summary}\n{compressed_summary}" if self.summary else compressed_summary

        # Keep only the most recent messages
        self.messages = deque(list(self.messages)[-10:], maxlen=self.max_messages)
5.2 Long-term memory
import os
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import chromadb
from openai import OpenAI


class LongTermMemory:
    """Long-term memory backed by a vector database."""

    def __init__(self, collection_name: str = "research_memory",
                 persist_dir: str = "./data/chroma"):
        # PersistentClient replaces the legacy
        # Settings(chroma_db_impl="duckdb+parquet") configuration, which
        # current ChromaDB releases no longer accept.
        self.client = chromadb.PersistentClient(path=persist_dir)

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )

        self.embedder = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def store(self, content: str, metadata: Optional[Dict] = None) -> str:
        """Store a memory and return its ID."""
        memory_id = str(uuid.uuid4())
        embedding = self._get_embedding(content)

        self.collection.add(
            ids=[memory_id],
            embeddings=[embedding],
            documents=[content],
            # ChromaDB rejects empty metadata dictionaries
            metadatas=[metadata] if metadata else None,
        )

        return memory_id

    def recall(self, query: str, n_results: int = 5,
               filter_metadata: Optional[Dict] = None) -> List[Dict]:
        """Retrieve memories related to the query."""
        query_embedding = self._get_embedding(query)

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=filter_metadata,
            include=["documents", "metadatas", "distances"],
        )

        memories = []
        for i in range(len(results["ids"][0])):
            memories.append({
                "id": results["ids"][0][i],
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "distance": results["distances"][0][i],
            })

        return memories

    def _get_embedding(self, text: str) -> List[float]:
        """Generate an embedding."""
        response = self.embedder.embeddings.create(
            model="text-embedding-3-small",
            input=text,
        )
        return response.data[0].embedding

    def store_research_experience(self, topic: str, query: str, result: str,
                                  lessons: Optional[str] = None):
        """Store a research experience."""
        content = (
            f"Research topic: {topic}\n"
            f"Query: {query}\n"
            f"Result: {result}\n"
            f"Lessons learned: {lessons or 'none'}"
        )
        self.store(
            content,
            metadata={
                "type": "research_experience",
                "topic": topic,
                "timestamp": datetime.now().isoformat(),
            },
        )

    def get_similar_research(self, topic: str, limit: int = 3) -> List[Dict]:
        """Return similar past research."""
        return self.recall(
            topic,
            n_results=limit,
            filter_metadata={"type": "research_experience"},
        )
6. Multi-agent collaboration
6.1 Agent base class
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI


class BaseAgent(ABC):
    """Agent base class."""

    def __init__(self, name: str, llm: ChatOpenAI,
                 tools: Optional[List[BaseTool]] = None,
                 system_prompt: Optional[str] = None):
        self.name = name
        self.llm = llm
        self.tools = tools or []
        self.system_prompt = system_prompt or self._default_system_prompt()

        # Bind the tools to the model
        if self.tools:
            self.llm_with_tools = self.llm.bind_tools(self.tools)
        else:
            self.llm_with_tools = self.llm

    @abstractmethod
    def _default_system_prompt(self) -> str:
        """Return the default system prompt."""

    @abstractmethod
    def process(self, input_data: Any) -> Any:
        """Process a task."""

    def _call_llm(self, messages: List) -> AIMessage:
        """Call the LLM and return its message, including any tool calls."""
        full_messages = [SystemMessage(content=self.system_prompt)] + messages
        return self.llm_with_tools.invoke(full_messages)

    def _execute_tools(self, tool_calls: List) -> List[Dict]:
        """Execute the requested tool calls."""
        results = []

        for tool_call in tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]

            # Look up the tool by name
            tool = next((t for t in self.tools if t.name == tool_name), None)

            if tool is None:
                results.append({"tool": tool_name, "error": "unknown tool", "success": False})
                continue

            try:
                result = tool.invoke(tool_args)
                results.append({"tool": tool_name, "result": result, "success": True})
            except Exception as e:
                results.append({"tool": tool_name, "error": str(e), "success": False})

        return results
6.2 Researcher agent
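import json
from typing import Dict, List

from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from .base import BaseAgent


class ResearcherAgent(BaseAgent):
    """Researcher agent: collects information."""

    def __init__(self, llm: ChatOpenAI, tools: List[BaseTool]):
        super().__init__(name="Researcher", llm=llm, tools=tools)

    def _default_system_prompt(self) -> str:
        return """You are a professional researcher who collects and organizes information.

Your responsibilities:
1. Decide what kinds of information the research topic requires
2. Use the search tools to find relevant information
3. Check the reliability of the information
4. Organize and categorize what you collect

Working principles:
- Use reliable sources
- Be as comprehensive as possible
- Cite the source of each piece of information
- Flag contradictions when you find them"""

    def process(self, topic: str, depth: str = "normal") -> Dict:
        """Run a research task.

        Args:
            topic: Research topic.
            depth: Research depth (quick/normal/deep).

        Returns:
            The research result.
        """
        # Step 1: plan the research
        plan = self._plan_research(topic, depth)

        # Step 2: run the searches
        collected_info = [self._search(query) for query in plan["queries"]]

        # Step 3: organize the results
        organized = self._organize_info(collected_info)

        return {
            "topic": topic,
            "plan": plan,
            "collected_info": collected_info,
            "organized_info": organized,
            "sources": self._collect_sources(collected_info),
        }

    def _plan_research(self, topic: str, depth: str) -> Dict:
        """Plan the research steps."""
        prompt = f"""Research topic: {topic}
Research depth: {depth}

Plan the research:
1. Which aspects need to be searched?
2. Which keywords should be used?
3. What kinds of data do you expect to find?

Return JSON:
{{
    "aspects": ["aspect 1", "aspect 2", ...],
    "queries": ["query 1", "query 2", ...],
    "expected_data": ["data type 1", ...]
}}"""
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return self._parse_plan(response.content, topic)

    def _parse_plan(self, content: str, topic: str) -> Dict:
        """Parse the JSON plan; fall back to searching the topic itself."""
        try:
            plan = json.loads(content[content.index("{"):content.rindex("}") + 1])
        except ValueError:  # no braces found, or invalid JSON
            plan = {}
        plan.setdefault("aspects", [])
        plan.setdefault("queries", [topic])
        plan.setdefault("expected_data", [])
        return plan

    def _search(self, query: str) -> Dict:
        """Run one search."""
        response = self._call_llm([HumanMessage(content=f"Search: {query}")])

        # Execute any tool calls the model requested
        if getattr(response, "tool_calls", None):
            return {"query": query, "results": self._execute_tools(response.tool_calls)}

        return {"query": query, "results": []}

    def _collect_sources(self, collected_info: List[Dict]) -> List[str]:
        """Collect result URLs from the successful tool calls."""
        sources = []
        for info in collected_info:
            for call in info["results"]:
                result = call.get("result")
                items = result.get("results", []) if isinstance(result, dict) else result
                for item in items or []:
                    if isinstance(item, dict):
                        url = item.get("link") or item.get("url")
                        if url:
                            sources.append(url)
        return sources

    def _organize_info(self, collected_info: List) -> Dict:
        """Organize the collected information."""
        prompt = f"""Organize the following collected information:

{collected_info}

Requirements:
1. Group by topic
2. Remove duplicates
3. Cite sources
4. Highlight the key information"""
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return {"summary": response.content}
6.3 Analyst agent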
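from typing import Dict, List

from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from .base import BaseAgent


class AnalystAgent(BaseAgent):
    """Analyst agent: analyzes data."""

    def __init__(self, llm: ChatOpenAI, tools: List[BaseTool]):
        super().__init__(name="Analyst", llm=llm, tools=tools)

    def _default_system_prompt(self) -> str:
        return """You are a professional data analyst who analyzes and interprets data.

Your responsibilities:
1. Analyze the research data
2. Find trends and patterns
3. Compute key metrics
4. Draw conclusions

Analysis principles:
- Be data-driven
- Back every conclusion with evidence
- Consider multiple possibilities
- Identify the limitations of the data"""

    def process(self, research_data: Dict, analysis_type: str = "comprehensive") -> Dict:
        """Run an analysis task.

        Args:
            research_data: Research data.
            analysis_type: Analysis type (quick/comprehensive/deep).

        Returns:
            The analysis result.
        """
        # Step 1: understand the data
        understanding = self._understand_data(research_data)

        # Step 2: run the analysis
        if analysis_type == "quick":
            analysis = self._quick_analysis(research_data)
        else:
            analysis = self._comprehensive_analysis(research_data)

        # Step 3: generate insights
        insights = self._generate_insights(analysis)

        return {"understanding": understanding, "analysis": analysis, "insights": insights}

    def _ask(self, prompt: str) -> str:
        """Send a single prompt to the LLM and return the text."""
        return self.llm.invoke([HumanMessage(content=prompt)]).content

    def _understand_data(self, data: Dict) -> str:
        """Describe what the data contains (simplified, prompt-only version)."""
        return self._ask(f"Briefly describe what the following research data contains:\n\n{data}")

    def _quick_analysis(self, data: Dict) -> Dict:
        """Quick analysis."""
        prompt = f"""Do a quick analysis of the following data and extract the key information:

{data}

Provide:
1. Core findings (3 to 5 points)
2. Key data
3. Preliminary conclusions"""
        return {"analysis": self._ask(prompt)}

    def _comprehensive_analysis(self, data: Dict) -> Dict:
        """In-depth analysis (simplified, prompt-only version).

        The original left this step as a stub. A fuller version would call
        the calculation tools for the numbers before writing the analysis.
        """
        prompt = f"""Do an in-depth analysis of the following data:

{data}

Cover:
1. Key metrics
2. Trend analysis
3. Comparative analysis"""
        return {"analysis": self._ask(prompt)}

    def _generate_insights(self, analysis: Dict) -> List[str]:
        """Generate insights."""
        prompt = f"""Generate key insights from the following analysis:

{analysis}

Requirements:
1. Highlight the most important findings
2. Explain why each finding matters
3. Suggest possible actions"""
        return [line for line in self._ask(prompt).split("\n") if line.strip()]
6.4 Editor agent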
from typing import Dict, List

from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from .base import BaseAgent


class EditorAgent(BaseAgent):
    """Editor agent: writes the report."""

    def __init__(self, llm: ChatOpenAI, tools: List[BaseTool]):
        super().__init__(name="Editor", llm=llm, tools=tools)

    def _default_system_prompt(self) -> str:
        return """You are a professional technical writer and editor who writes research reports.

Your responsibilities:
1. Organize the report structure
2. Write clear content
3. Keep the logic coherent
4. Polish the language

Writing principles:
- Clear structure
- Concise language
- Evidence for every claim
- Easy to understand"""

    def process(self, research_result: Dict, analysis_result: Dict,
                format: str = "markdown") -> Dict:
        """Write a research report.

        Args:
            research_result: Research result.
            analysis_result: Analysis result.
            format: Output format.

        Returns:
            The outline, the sections, and the final report.
        """
        # Step 1: plan the report structure
        outline = self._create_outline(research_result, analysis_result)

        # Step 2: write each section
        sections = self._write_sections(outline, research_result, analysis_result)

        # Step 3: assemble and polish
        report = self._finalize_report(sections, format)

        return {"outline": outline, "sections": sections, "report": report}

    def _create_outline(self, research: Dict, analysis: Dict) -> List[str]:
        """Create the report outline."""
        prompt = f"""Create a report outline from the research and analysis results:

Research: {research}
Analysis: {analysis}

The outline should include:
1. Summary
2. Background
3. Research method
4. Main findings
5. Conclusions
6. Recommendations"""
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content.split("\n")

    def _write_sections(self, outline: List, research: Dict, analysis: Dict) -> Dict:
        """Write each section of the outline."""
        sections = {}

        for section in outline:
            if section.strip():
                prompt = f"""Write the section: {section}

Reference material:
Research data: {research}
Analysis: {analysis}

Requirements:
- Substantive content
- Accurate data
- Clear logic"""
                response = self.llm.invoke([HumanMessage(content=prompt)])
                sections[section] = response.content

        return sections

    def _finalize_report(self, sections: Dict, format: str) -> str:
        """Assemble the report."""
        if format == "markdown":
            lines = []
            for section, content in sections.items():
                lines.append(f"## {section}\n")
                lines.append(content)
                lines.append("\n")
            return "\n".join(lines)

        return str(sections)
7. Workflow orchestration
from typing import Dict, TypedDict

from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

from ..agents.analyst import AnalystAgent
from ..agents.editor import EditorAgent
from ..agents.researcher import ResearcherAgent
from ..memory.long_term import LongTermMemory
from ..memory.short_term import ShortTermMemory
from ..tools.calculator import CalculatorTools
from ..tools.database import DatabaseTools
from ..tools.file_processor import FileTools
from ..tools.search import SearchTools


# State definition. A TypedDict is used so that LangGraph sees every key the
# nodes read or write, including analysis_type, format, and output_file.
class ResearchState(TypedDict, total=False):
    """Research workflow state."""
    topic: str
    depth: str
    analysis_type: str
    format: str
    research_result: Dict
    analysis_result: Dict
    report: str
    output_file: str
    messages: list


class ResearchWorkflow:
    """Research workflow."""

    def __init__(self, openai_api_key: str):
        # Initialize the LLM
        self.llm = ChatOpenAI(model="gpt-4o", temperature=0.7, api_key=openai_api_key)

        # Initialize the tools
        self.search_tools = SearchTools()
        self.db_tools = DatabaseTools()
        self.file_tools = FileTools()
        self.calc_tools = CalculatorTools()

        # Initialize the memory
        self.short_term_memory = ShortTermMemory()
        self.long_term_memory = LongTermMemory()

        # Initialize the agents
        self.researcher = ResearcherAgent(
            self.llm,
            [self.search_tools.web_search, self.search_tools.deep_search],
        )
        self.analyst = AnalystAgent(
            self.llm,
            [self.calc_tools.calculate, self.calc_tools.analyze_statistics],
        )
        self.editor = EditorAgent(
            self.llm,
            [self.file_tools.save_report, self.file_tools.generate_markdown_report],
        )

        # Build the workflow graph
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the state graph."""
        workflow = StateGraph(ResearchState)

        # Nodes
        workflow.add_node("research", self._research_node)
        workflow.add_node("analyze", self._analyze_node)
        workflow.add_node("write", self._write_node)
        workflow.add_node("save", self._save_node)

        # Edges
        workflow.add_edge("research", "analyze")
        workflow.add_edge("analyze", "write")
        workflow.add_edge("write", "save")
        workflow.add_edge("save", END)

        # Entry point
        workflow.set_entry_point("research")

        return workflow.compile()

    def _research_node(self, state: ResearchState) -> ResearchState:
        """Research node."""
        result = self.researcher.process(state["topic"], state.get("depth", "normal"))
        state["research_result"] = result

        # Record in short-term memory
        self.short_term_memory.add_message(
            "assistant",
            f"Research finished: {result['organized_info']['summary']}",
        )

        return state

    def _analyze_node(self, state: ResearchState) -> ResearchState:
        """Analysis node."""
        result = self.analyst.process(
            state["research_result"],
            state.get("analysis_type", "comprehensive"),
        )
        state["analysis_result"] = result

        self.short_term_memory.add_message("assistant", f"Analysis finished: {result['insights']}")

        return state

    def _write_node(self, state: ResearchState) -> ResearchState:
        """Writing node."""
        result = self.editor.process(
            state["research_result"],
            state["analysis_result"],
            state.get("format", "markdown"),
        )
        state["report"] = result["report"]

        self.short_term_memory.add_message("assistant", "Report written")

        return state

    def _save_node(self, state: ResearchState) -> ResearchState:
        """Save node."""
        # Tools are invoked with a dict of arguments rather than called directly.
        # Save to the database
        self.db_tools.save_research.invoke({
            "topic": state["topic"],
            "query": state["topic"],
            "result": state["report"],
            "sources": state["research_result"].get("sources", []),
        })

        # Save to long-term memory
        self.long_term_memory.store_research_experience(
            topic=state["topic"],
            query=state["topic"],
            result=state["report"],
        )

        # Save to a file
        filepath = self.file_tools.save_report.invoke({
            "filename": f"research_{state['topic'][:20]}",
            "content": state["report"],
            "format": "markdown",
        })

        state["output_file"] = filepath

        return state

    def run(self, topic: str, depth: str = "normal",
            analysis_type: str = "comprehensive", format: str = "markdown") -> Dict:
        """Run the research workflow.

        Args:
            topic: Research topic.
            depth: Research depth.
            analysis_type: Analysis type.
            format: Output format.

        Returns:
            The research result.
        """
        initial_state = ResearchState(
            topic=topic,
            depth=depth,
            analysis_type=analysis_type,
            format=format,
            messages=[],
        )

        final_state = self.graph.invoke(initial_state)

        return {
            "topic": topic,
            "report": final_state["report"],
            "research_result": final_state["research_result"],
            "analysis_result": final_state["analysis_result"],
            "output_file": final_state.get("output_file"),
        }
8. API service
import os
import uuid
from typing import Optional

from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from ..workflows.research_flow import ResearchWorkflow

load_dotenv()

app = FastAPI(
    title="Research Agent API",
    description="AI research assistant API",
    version="1.0.0",
)

# CORS: wide open for development; restrict the origins in production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the workflow
workflow = ResearchWorkflow(os.getenv("OPENAI_API_KEY"))


# Request and response models
class ResearchRequest(BaseModel):
    topic: str
    depth: Optional[str] = "normal"
    analysis_type: Optional[str] = "comprehensive"
    format: Optional[str] = "markdown"


class ResearchResponse(BaseModel):
    success: bool
    topic: str
    task_id: Optional[str] = None  # needed to poll /api/research/{task_id}
    report: Optional[str] = None
    output_file: Optional[str] = None
    error: Optional[str] = None


# In-memory task store (lost on restart)
tasks = {}


@app.post("/api/research", response_model=ResearchResponse)
async def create_research(request: ResearchRequest, background_tasks: BackgroundTasks):
    """Create a research task."""
    task_id = str(uuid.uuid4())
    tasks[task_id] = {"status": "pending"}

    # Run in the background
    background_tasks.add_task(run_research_task, task_id, request)

    return ResearchResponse(success=True, topic=request.topic, task_id=task_id)


@app.get("/api/research/{task_id}")
async def get_research(task_id: str):
    """Return the result of a research task."""
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    return tasks[task_id]


@app.get("/api/history")
async def get_history(limit: int = 10):
    """Return the research history."""
    # The history tool is invoked with a dict of arguments
    return workflow.db_tools.get_research_history.invoke({"limit": limit})


def run_research_task(task_id: str, request: ResearchRequest):
    """Run a research task and record its outcome."""
    try:
        result = workflow.run(
            topic=request.topic,
            depth=request.depth,
            analysis_type=request.analysis_type,
            format=request.format,
        )
        tasks[task_id] = {"status": "completed", "result": result}
    except Exception as e:
        tasks[task_id] = {"status": "failed", "error": str(e)}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
9. Deployment and monitoring
9.1 Docker configuration
# Dockerfile

FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the code
COPY . .

# Create the data directories
RUN mkdir -p /app/data /app/output

# Expose the port
EXPOSE 8000

# Start command
CMD ["python", "-m", "src.api.main"]

# docker-compose.yml

version: "3.8"

services:
  research-agent:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - SERPER_API_KEY=${SERPER_API_KEY}
      - TAVILY_API_KEY=${TAVILY_API_KEY}
    volumes:
      - ./data:/app/data
      - ./output:/app/output
    depends_on:
      - redis

  redis:
    image: redis:alpine
    ports:
      - "6379:6379"

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
9.2 Monitoring configuration
import time
from functools import wraps

from prometheus_client import Counter, Gauge, Histogram

# Metric definitions
RESEARCH_COUNT = Counter(
    "research_total",
    "Total number of research tasks",
)

RESEARCH_DURATION = Histogram(
    "research_duration_seconds",
    "Time spent on research tasks",
    buckets=[10, 30, 60, 120, 300, 600],
)

ACTIVE_RESEARCH = Gauge(
    "active_research",
    "Number of active research tasks",
)

TOOL_CALLS = Counter(
    "tool_calls_total",
    "Total number of tool calls",
    ["tool_name", "status"],
)


def monitor_research(func):
    """Decorator that records metrics for a research task."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        ACTIVE_RESEARCH.inc()
        RESEARCH_COUNT.inc()

        start_time = time.time()
        try:
            return func(*args, **kwargs)
        finally:
            RESEARCH_DURATION.observe(time.time() - start_time)
            ACTIVE_RESEARCH.dec()

    return wrapper


def track_tool_call(tool_name: str, success: bool):
    """Record a tool call."""
    status = "success" if success else "error"
    TOOL_CALLS.labels(tool_name=tool_name, status=status).inc()
10. Usage examples
10.1 Command line
import os

from src.workflows.research_flow import ResearchWorkflow


def main():
    # workflow.run() is synchronous, so no event loop is needed
    workflow = ResearchWorkflow(os.getenv("OPENAI_API_KEY"))

    # Run the research
    result = workflow.run(
        topic="Trends in enterprise adoption of AI agents",
        depth="deep",
        analysis_type="comprehensive",
    )

    print("=" * 50)
    print("Research report")
    print("=" * 50)
    print(result["report"])
    print("=" * 50)
    print(f"Output file: {result['output_file']}")


if __name__ == "__main__":
    main()
10.2 API calls
# Create a research task
curl -X POST http://localhost:8000/api/research \
  -H "Content-Type: application/json" \
  -d '{
    "topic": "AI industry trends in 2024",
    "depth": "normal",
    "analysis_type": "comprehensive"
  }'

# Fetch the result (replace {task_id} with the ID of your task)
curl http://localhost:8000/api/research/{task_id}

# Fetch the history (quoted so the shell does not interpret "?")
curl "http://localhost:8000/api/history?limit=10"
FAQ
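Q1: How do I handle a large volume of search requests?
A: Use asynchronous processing and caching:
- Cache search results in Redis
- Use a background task queue
- Limit the number of concurrent requests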
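As a rough illustration of the caching and concurrency points, the sketch below wraps any search function with a time-limited cache and a cap on simultaneous calls. It is an assumption-laden stand-in: an in-memory dictionary takes the place of Redis, and `CachedSearch` and `fake_search` are hypothetical names, not part of the project code.

```python
import threading
import time
from typing import Callable, Dict, List, Tuple


class CachedSearch:
    """Wrap a search function with a TTL cache and a concurrency cap."""

    def __init__(self, search_fn: Callable[[str], List[dict]],
                 ttl_seconds: float = 300.0, max_concurrent: int = 4):
        self._search_fn = search_fn
        self._ttl = ttl_seconds
        self._cache: Dict[str, Tuple[float, List[dict]]] = {}
        self._lock = threading.Lock()
        self._slots = threading.Semaphore(max_concurrent)

    def search(self, query: str) -> List[dict]:
        now = time.monotonic()
        with self._lock:
            hit = self._cache.get(query)
            if hit and now - hit[0] < self._ttl:
                return hit[1]  # fresh cache entry, no backend call
        with self._slots:  # at most max_concurrent backend calls at once
            results = self._search_fn(query)
        with self._lock:
            self._cache[query] = (time.monotonic(), results)
        return results


calls = []


def fake_search(query: str) -> List[dict]:
    """Hypothetical backend that records every call it receives."""
    calls.append(query)
    return [{"title": f"result for {query}"}]


cached = CachedSearch(fake_search, ttl_seconds=60)
first = cached.search("ai agents")
second = cached.search("ai agents")  # served from the cache
print(len(calls))
```

Two threads that miss on the same query at the same moment will both call the backend; that is usually acceptable for a cache of search results.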
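Q2: How do I improve report quality?
A:
- Add verification of information sources
- Use a stronger model (for example, GPT-4)
- Iterate over several rounds of refinement
- Add a human review step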
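The multi-round refinement point can be sketched as a critique-and-revise loop. This is only an outline under assumptions: `critique` and `revise` would be LLM calls in practice, and the toy functions here are hypothetical stand-ins so the loop can run on its own.

```python
from typing import Callable, List


def refine(draft: str,
           critique: Callable[[str], List[str]],
           revise: Callable[[str, List[str]], str],
           max_rounds: int = 3) -> str:
    """Revise the draft until the critic finds no issues or rounds run out."""
    for _ in range(max_rounds):
        issues = critique(draft)
        if not issues:
            break
        draft = revise(draft, issues)
    return draft


# Stand-ins for LLM calls: this critic only demands a sources line.
def toy_critique(text: str) -> List[str]:
    return [] if "Sources:" in text else ["missing sources"]


def toy_revise(text: str, issues: List[str]) -> str:
    return text + "\nSources: (to be filled in)"


result = refine("Draft report", toy_critique, toy_revise)
print(result)
```

The `max_rounds` cap matters because every round costs at least two model calls, and a critic that never approves would otherwise loop forever.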
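Q3: How do I control cost?
A:
- Use cheaper models for simple tasks
- Cache the results of common queries
- Limit search depth and the number of searches
- Monitor token usage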
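Two of these points, routing simple tasks to a cheaper model and monitoring token usage, can be sketched in a few lines. The model names, the length threshold, and the `TokenBudget` class are all hypothetical; a real router would more likely decide by task type than by prompt length.

```python
from dataclasses import dataclass


@dataclass
class TokenBudget:
    """Track token usage against a fixed limit."""
    limit: int
    used: int = 0

    def charge(self, tokens: int) -> None:
        """Record usage; refuse the charge if it would exceed the limit."""
        if self.used + tokens > self.limit:
            raise RuntimeError("token budget exceeded")
        self.used += tokens


def pick_model(prompt: str, cheap: str = "small-model", strong: str = "large-model",
               threshold_chars: int = 500) -> str:
    """Route short prompts to the cheap model and long ones to the strong model."""
    return cheap if len(prompt) <= threshold_chars else strong


budget = TokenBudget(limit=1000)
budget.charge(400)
print(pick_model("Summarize this sentence."), budget.used)
```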
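Q4: How do I add a new data source?
A:
- Create a new tool class
- Implement the functions and expose them as tools (for example with @tool)
- Register the tools with an agent
- Update the workflow configuration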
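The register-then-call pattern behind these steps can be shown without any framework. The sketch below is a plain-Python stand-in rather than the LangChain mechanism the project uses; `ToolRegistry` and `internal_wiki_search` are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, List


class ToolRegistry:
    """Minimal name-to-function registry for tools."""

    def __init__(self) -> None:
        self._tools: Dict[str, Callable[..., object]] = {}

    def register(self, name: str) -> Callable:
        """Decorator that registers a function under the given tool name."""
        def decorator(fn: Callable[..., object]) -> Callable[..., object]:
            if name in self._tools:
                raise ValueError(f"tool already registered: {name}")
            self._tools[name] = fn
            return fn
        return decorator

    def call(self, name: str, **kwargs) -> object:
        """Look up a tool by name and call it with keyword arguments."""
        if name not in self._tools:
            raise KeyError(f"unknown tool: {name}")
        return self._tools[name](**kwargs)

    def names(self) -> List[str]:
        return sorted(self._tools)


registry = ToolRegistry()


@registry.register("internal_wiki_search")
def internal_wiki_search(query: str) -> list:
    """Hypothetical data source; a real version would query an internal API."""
    return [{"title": f"Wiki page about {query}"}]


print(registry.call("internal_wiki_search", query="pricing"))
```

Adding a data source then means writing one function and registering it; nothing that dispatches by tool name has to change.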
Summary
This article built a complete research assistant agent from scratch, covering:

Project summary
- Architecture: multi-agent collaboration plus a state-machine workflow
- Tool integration: four categories (search, database, files, calculation)
- Memory system: short-term conversation memory plus long-term vector storage
- API service: FastAPI, background tasks, Docker deployment
- Observability: Prometheus metrics plus logging

The result is a runnable, extensible skeleton for a production-grade agent project.
Next in the series
"The Future of Agents: The Road to AGI"
A look at where agent technology is heading and the challenges ahead.