LangGraph Human-In-The-Loop机制实战指南
LangGraph 的 interrupt 到底怎么用?直接看实际场景:当 Agent 在执行过程中需要问用户一个问题,然后等着用户回答再继续往下走——这就是 Human-In-The-Loop 最典型的实现方式。下面就用一个完整的例子,把整个流程拆开讲清楚。
核心概念
interrupt() 函数
interrupt() 是 LangGraph 中实现 human-in-the-loop 的核心函数,它的作用就是在节点内部“喊停”,然后等着用户把答案送回来。
具体怎么工作的?
第一次调用:抛出
GraphInterrupt异常,图执行立刻暂停。异常里带着你传进去的value(通常是一个问题或提示),同时图的状态被 checkpointer 保存下来。恢复后调用:当用户通过
Command(resume=...)把值传回来,节点会从头重新执行(注意,不是从中断点继续,而是从节点开头重跑)。interrupt()检测到 resume 值之后,就不再抛异常了,直接返回这个值。
几个要点需要注意:
- 必须启用 checkpointer:
interrupt()依赖 checkpointer 保存状态,没有它不行。 - 节点会重新执行:恢复时从节点开头重跑,不是原地继续。这在设计逻辑时一定要心中有数。
- 支持多个中断:一个节点里可以连续调用多个
interrupt(),恢复时按顺序匹配 resume 值。
Command 类
Command 是用来控制图执行和恢复的指令对象,主要用途就是恢复被 interrupt() 打断的执行。
主要参数:
resume:恢复中断的值。- 单个值:
Command(resume="answer")恢复下一个中断 - 字典映射:
Command(resume={interrupt_id: "answer"})恢复指定ID的中断
- 单个值:
update:更新图状态。Command(update={"key": "value"})goto:跳转到指定节点。Command(goto="node_name")
Checkpointer(检查点)
Checkpointer 是 LangGraph 的持久化层,负责保存和恢复图的状态,也是 interrupt() 能够正常工作的基础。
核心作用就三件事:
- 保存图状态:每个执行步骤(superstep)都会保存一个状态快照。
- 支持恢复执行:可以从任意检查点恢复。
- 管理多个会话:通过
thread_id区分不同的执行会话。
基本用法:
from langgraph.checkpoint.memory import InMemorySa ver
# 创建 checkpointer
checkpointer = InMemorySa ver()
# 编译图时启用 checkpointer
graph = builder.compile(checkpointer=checkpointer)
# 执行时需要提供 thread_id
config = {"configurable": {"thread_id": "thread_1"}}
graph.stream(input, config)
为什么 interrupt() 需要 checkpointer?答案很直接:interrupt() 暂停时得把当前状态存下来,恢复时再从检查点读出来。没有 checkpointer,状态就丢了,没法恢复。
常用的实现:
InMemorySa ver:内存存储,适合开发和测试。PostgresSa ver:PostgreSQL 存储,生产环境标配。SqliteSa ver:SQLite 存储,轻量应用场景。
Thread 和 Checkpoint ID:
thread_id:会话标识符,区分不同的执行会话,必需参数。checkpoint_id:检查点标识符,用于从特定检查点恢复,可选参数。
# 基本配置
config = {"configurable": {"thread_id": "user_123"}}
# 从特定检查点恢复
config = {"configurable": {"thread_id": "user_123", "checkpoint_id": "checkpoint_abc"}}
完整示例
下面的示例展示了一个完整流程:Agent 执行过程中调用 interrupt() 向用户提问,然后监听中断事件,自动获取用户输入,再恢复执行。看代码之前先理一下思路——整个流程其实就是一个递归循环:收到中断 → 拿答案 → 放回去重跑,直到没有中断为止。
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.types import Command, interrupt
from langgraph.checkpoint.memory import InMemorySa ver
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel
import uuid
# 定义工具和模型
@tool
def search(query: str):
return f"搜索结果: {query}"
class AskHuman(BaseModel):
question: str
model = ChatAnthropic(model="claude-3-5-sonnet-latest")
model = model.bind_tools([search, AskHuman])
# 定义节点
def call_model(state):
messages = state["messages"]
response = model.invoke(messages)
return {"messages": [response]}
def ask_human(state):
tool_call = state["messages"][-1].tool_calls[0]
ask = AskHuman.model_validate(tool_call["args"])
answer = interrupt(ask.question) # 中断执行,等待用户输入
return {"messages": [{"tool_call_id": tool_call["id"],
"type": "tool",
"content": answer}]}
# 构建图
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode([search]))
workflow.add_node("ask_human", ask_human)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
"agent",
lambda state: ("ask_human"
if state["messages"][-1].tool_calls
and state["messages"][-1].tool_calls[0]["name"] == "AskHuman"
else "tools"
if state["messages"][-1].tool_calls
else END)
)
workflow.add_edge("tools", "agent")
workflow.add_edge("ask_human", "agent")
app = workflow.compile(checkpointer=InMemorySa ver())
# 获取用户输入的方法(可从数据库、API、消息队列等获取)
def get_user_input(question: str, interrupt_id: str) -> str:
user_inputs = {
"Where are you located?": "San Francisco",
"What is your name?": "Alice",
"What is your age?": "25"
}
return user_inputs.get(question, "Unknown")
# 监听中断并自动恢复执行
def run_with_auto_resume(app, initial_input, config):
for event in app.stream(initial_input, config, stream_mode="updates"):
if "__interrupt__" in event:
interrupt_info = event["__interrupt__"][0]
user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
return run_with_auto_resume(app, Command(resume=user_answer), config)
return []
# 使用
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
initial_input = {"messages": [("user", "Ask the user where they are, then look up the weather there")]}
run_with_auto_resume(app, initial_input, config)
工作原理
执行流程
- 初始执行:图开始跑,Agent 调用工具,路由到
ask_human节点,interrupt()被调用并抛出异常。 - 检测中断:监听流事件,检测到
__interrupt__键,提取中断信息(问题+中断ID)。 - 获取用户输入:调用
get_user_input()方法,从外部系统拿到答案。 - 恢复执行:用
Command(resume=user_answer)恢复,递归调用继续处理。 - 完成执行:没有更多中断了,返回所有事件。
关键机制
中断检测:
if "__interrupt__" in event:
interrupt_info = event["__interrupt__"][0]
# interrupt_info.value: 问题或提示信息
# interrupt_info.id: 中断的唯一标识符
恢复执行:
Command(resume=user_answer)
# 单个值
Command(resume={interrupt_id: user_answer})
# 字典映射(多个中断)
递归处理:
def run_with_auto_resume(app, initial_input, config):
for event in app.stream(initial_input, config):
if "__interrupt__" in event:
# 获取用户输入并递归恢复
return run_with_auto_resume(app, Command(resume=user_answer), config)
return []
实际应用
从数据库获取用户输入
def get_user_input(question: str, interrupt_id: str) -> str:
import sqlite3
conn = sqlite3.connect('user_inputs.db')
cursor = conn.cursor()
cursor.execute("SELECT answer FROM user_inputs WHERE interrupt_id = ?", (interrupt_id,))
result = cursor.fetchone()
conn.close()
if result:
return result[0]
else:
raise ValueError(f"未找到中断ID {interrupt_id} 对应的用户输入")
从 API 获取用户输入
def get_user_input(question: str, interrupt_id: str) -> str:
import requests
response = requests.post("https://api.example.com/get-user-input",
json={"interrupt_id": interrupt_id, "question": question})
if response.status_code == 200:
return response.json()["answer"]
else:
raise ValueError(f"API 调用失败: {response.status_code}")
从消息队列获取用户输入
def get_user_input(question: str, interrupt_id: str) -> str:
import redis
import time
r = redis.Redis(host='localhost', port=6379, db=0)
# 轮询等待用户输入
while True:
answer = r.get(f"user_input:{interrupt_id}")
if answer:
return answer.decode('utf-8')
time.sleep(0.1) # 等待100ms后重试
从 WebSocket 获取用户输入
def get_user_input(question: str, interrupt_id: str) -> str:
import asyncio
import websockets
import json
async def wait_for_input():
async with websockets.connect("ws://localhost:8765") as websocket:
await websocket.send(json.dumps({"interrupt_id": interrupt_id,
"question": question}))
response = await websocket.recv()
return json.loads(response)["answer"]
return asyncio.run(wait_for_input())
最佳实践
1. 监听中断事件
通过检查流事件中的 __interrupt__ 键来检测中断:
for event in app.stream(input, config, stream_mode="updates"):
if "__interrupt__" in event:
interrupt_info = event["__interrupt__"][0]
# 处理中断
2. 处理多个中断
如果图执行过程中可能触发多次中断,建议用循环代替递归,避免无限递归的风险:
def run_with_auto_resume(app, initial_input, config, max_iterations=10):
iteration = 0
while iteration < max_iterations:
iteration += 1
interrupt_detected = False
input_data = initial_input if iteration == 1 else Command(resume=user_answer)
for event in app.stream(input_data, config, stream_mode="updates"):
if "__interrupt__" in event:
interrupt_info = event["__interrupt__"][0]
user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
interrupt_detected = True
break
if not interrupt_detected:
return events
raise RuntimeError(f"达到最大迭代次数 {max_iterations}")
3. 错误处理
生产环境一定要加超时和异常处理,别让系统一直傻等:
def get_user_input(question: str, interrupt_id: str, timeout: float = 30.0) -> str:
import time
start_time = time.time()
while time.time() - start_time < timeout:
try:
# 尝试获取用户输入
answer = fetch_from_external_system(interrupt_id)
if answer:
return answer
except Exception as e:
logger.error(f"获取用户输入失败: {e}")
time.sleep(0.1)
raise TimeoutError(f"获取用户输入超时: {timeout}秒")
4. 状态管理
用 thread_id 管理不同用户的执行会话:
# 为每个用户创建独立的 thread_id
config = {"configurable": {"thread_id": f"user_{user_id}"}}
# 可以从特定检查点恢复
config = {"configurable": {"thread_id": f"user_{user_id}", "checkpoint_id": checkpoint_id}}
