LangGraph Human-In-The-Loop机制实战指南

2026-06-19阅读 0热度 0
Loop

LangGraph 的 interrupt 到底怎么用?直接看实际场景:当 Agent 在执行过程中需要问用户一个问题,然后等着用户回答再继续往下走——这就是 Human-In-The-Loop 最典型的实现方式。下面就用一个完整的例子,把整个流程拆开讲清楚。

LangGraph: Human-In-The-Loop 实现机制

核心概念

interrupt() 函数

interrupt() 是 LangGraph 中实现 human-in-the-loop 的核心函数,它的作用就是在节点内部“喊停”,然后等着用户把答案送回来。

具体怎么工作的?

  1. 第一次调用:抛出 GraphInterrupt 异常,图执行立刻暂停。异常里带着你传进去的 value(通常是一个问题或提示),同时图的状态被 checkpointer 保存下来。

  2. 恢复后调用:当用户通过 Command(resume=...) 把值传回来,节点会从头重新执行(注意,不是从中断点继续,而是从节点开头重跑)。interrupt() 检测到 resume 值之后,就不再抛异常了,直接返回这个值。

几个要点需要注意:

  • 必须启用 checkpointerinterrupt() 依赖 checkpointer 保存状态,没有它不行。
  • 节点会重新执行:恢复时从节点开头重跑,不是原地继续。这在设计逻辑时一定要心中有数。
  • 支持多个中断:一个节点里可以连续调用多个 interrupt(),恢复时按顺序匹配 resume 值。

Command 类

Command 是用来控制图执行和恢复的指令对象,主要用途就是恢复被 interrupt() 打断的执行。

主要参数:

  • resume:恢复中断的值。

    • 单个值:Command(resume="answer") 恢复下一个中断
    • 字典映射:Command(resume={interrupt_id: "answer"}) 恢复指定ID的中断
  • update:更新图状态。

    Command(update={"key": "value"})

  • goto:跳转到指定节点。

    Command(goto="node_name")

Checkpointer(检查点)

Checkpointer 是 LangGraph 的持久化层,负责保存和恢复图的状态,也是 interrupt() 能够正常工作的基础。

核心作用就三件事:

  1. 保存图状态:每个执行步骤(superstep)都会保存一个状态快照。
  2. 支持恢复执行:可以从任意检查点恢复。
  3. 管理多个会话:通过 thread_id 区分不同的执行会话。

基本用法:

from langgraph.checkpoint.memory import InMemorySa ver

# 创建 checkpointer
checkpointer = InMemorySa ver()

# 编译图时启用 checkpointer
graph = builder.compile(checkpointer=checkpointer)

# 执行时需要提供 thread_id
config = {"configurable": {"thread_id": "thread_1"}}
graph.stream(input, config)

为什么 interrupt() 需要 checkpointer?答案很直接:interrupt() 暂停时得把当前状态存下来,恢复时再从检查点读出来。没有 checkpointer,状态就丢了,没法恢复。

常用的实现:

  • InMemorySa ver:内存存储,适合开发和测试。
  • PostgresSa ver:PostgreSQL 存储,生产环境标配。
  • SqliteSa ver:SQLite 存储,轻量应用场景。

Thread 和 Checkpoint ID:

  • thread_id:会话标识符,区分不同的执行会话,必需参数。
  • checkpoint_id:检查点标识符,用于从特定检查点恢复,可选参数。
# 基本配置
config = {"configurable": {"thread_id": "user_123"}}

# 从特定检查点恢复
config = {"configurable": {"thread_id": "user_123", "checkpoint_id": "checkpoint_abc"}}

完整示例

下面的示例展示了一个完整流程:Agent 执行过程中调用 interrupt() 向用户提问,然后监听中断事件,自动获取用户输入,再恢复执行。看代码之前先理一下思路——整个流程其实就是一个递归循环:收到中断 → 拿答案 → 放回去重跑,直到没有中断为止。

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.types import Command, interrupt
from langgraph.checkpoint.memory import InMemorySa ver
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel
import uuid

# 定义工具和模型
@tool
def search(query: str):
    return f"搜索结果: {query}"

class AskHuman(BaseModel):
    question: str

model = ChatAnthropic(model="claude-3-5-sonnet-latest")
model = model.bind_tools([search, AskHuman])

# 定义节点
def call_model(state):
    messages = state["messages"]
    response = model.invoke(messages)
    return {"messages": [response]}

def ask_human(state):
    tool_call = state["messages"][-1].tool_calls[0]
    ask = AskHuman.model_validate(tool_call["args"])
    answer = interrupt(ask.question)              # 中断执行,等待用户输入
    return {"messages": [{"tool_call_id": tool_call["id"],
                          "type": "tool",
                          "content": answer}]}

# 构建图
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode([search]))
workflow.add_node("ask_human", ask_human)

workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
    "agent",
    lambda state: ("ask_human"
                   if state["messages"][-1].tool_calls
                   and state["messages"][-1].tool_calls[0]["name"] == "AskHuman"
                   else "tools"
                   if state["messages"][-1].tool_calls
                   else END)
)
workflow.add_edge("tools", "agent")
workflow.add_edge("ask_human", "agent")

app = workflow.compile(checkpointer=InMemorySa ver())

# 获取用户输入的方法(可从数据库、API、消息队列等获取)
def get_user_input(question: str, interrupt_id: str) -> str:
    user_inputs = {
        "Where are you located?": "San Francisco",
        "What is your name?": "Alice",
        "What is your age?": "25"
    }
    return user_inputs.get(question, "Unknown")

# 监听中断并自动恢复执行
def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config, stream_mode="updates"):
        if "__interrupt__" in event:
            interrupt_info = event["__interrupt__"][0]
            user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
            return run_with_auto_resume(app, Command(resume=user_answer), config)
    return []

# 使用
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
initial_input = {"messages": [("user", "Ask the user where they are, then look up the weather there")]}
run_with_auto_resume(app, initial_input, config)

工作原理

执行流程

  1. 初始执行:图开始跑,Agent 调用工具,路由到 ask_human 节点,interrupt() 被调用并抛出异常。
  2. 检测中断:监听流事件,检测到 __interrupt__ 键,提取中断信息(问题+中断ID)。
  3. 获取用户输入:调用 get_user_input() 方法,从外部系统拿到答案。
  4. 恢复执行:用 Command(resume=user_answer) 恢复,递归调用继续处理。
  5. 完成执行:没有更多中断了,返回所有事件。

关键机制

中断检测:

if "__interrupt__" in event:
    interrupt_info = event["__interrupt__"][0]
    # interrupt_info.value: 问题或提示信息
    # interrupt_info.id: 中断的唯一标识符

恢复执行:

Command(resume=user_answer)
# 单个值
Command(resume={interrupt_id: user_answer})
# 字典映射(多个中断)

递归处理:

def run_with_auto_resume(app, initial_input, config):
    for event in app.stream(initial_input, config):
        if "__interrupt__" in event:
            # 获取用户输入并递归恢复
            return run_with_auto_resume(app, Command(resume=user_answer), config)
    return []

实际应用

从数据库获取用户输入

def get_user_input(question: str, interrupt_id: str) -> str:
    import sqlite3
    conn = sqlite3.connect('user_inputs.db')
    cursor = conn.cursor()
    cursor.execute("SELECT answer FROM user_inputs WHERE interrupt_id = ?", (interrupt_id,))
    result = cursor.fetchone()
    conn.close()
    if result:
        return result[0]
    else:
        raise ValueError(f"未找到中断ID {interrupt_id} 对应的用户输入")

从 API 获取用户输入

def get_user_input(question: str, interrupt_id: str) -> str:
    import requests
    response = requests.post("https://api.example.com/get-user-input",
                             json={"interrupt_id": interrupt_id, "question": question})
    if response.status_code == 200:
        return response.json()["answer"]
    else:
        raise ValueError(f"API 调用失败: {response.status_code}")

从消息队列获取用户输入

def get_user_input(question: str, interrupt_id: str) -> str:
    import redis
    import time
    r = redis.Redis(host='localhost', port=6379, db=0)
    # 轮询等待用户输入
    while True:
        answer = r.get(f"user_input:{interrupt_id}")
        if answer:
            return answer.decode('utf-8')
        time.sleep(0.1)  # 等待100ms后重试

从 WebSocket 获取用户输入

def get_user_input(question: str, interrupt_id: str) -> str:
    import asyncio
    import websockets
    import json
    async def wait_for_input():
        async with websockets.connect("ws://localhost:8765") as websocket:
            await websocket.send(json.dumps({"interrupt_id": interrupt_id,
                                              "question": question}))
            response = await websocket.recv()
            return json.loads(response)["answer"]
    return asyncio.run(wait_for_input())

最佳实践

1. 监听中断事件

通过检查流事件中的 __interrupt__ 键来检测中断:

for event in app.stream(input, config, stream_mode="updates"):
    if "__interrupt__" in event:
        interrupt_info = event["__interrupt__"][0]
        # 处理中断

2. 处理多个中断

如果图执行过程中可能触发多次中断,建议用循环代替递归,避免无限递归的风险:

def run_with_auto_resume(app, initial_input, config, max_iterations=10):
    iteration = 0
    while iteration < max_iterations:
        iteration += 1
        interrupt_detected = False
        input_data = initial_input if iteration == 1 else Command(resume=user_answer)
        for event in app.stream(input_data, config, stream_mode="updates"):
            if "__interrupt__" in event:
                interrupt_info = event["__interrupt__"][0]
                user_answer = get_user_input(interrupt_info.value, interrupt_info.id)
                interrupt_detected = True
                break
        if not interrupt_detected:
            return events
    raise RuntimeError(f"达到最大迭代次数 {max_iterations}")

3. 错误处理

生产环境一定要加超时和异常处理,别让系统一直傻等:

def get_user_input(question: str, interrupt_id: str, timeout: float = 30.0) -> str:
    import time
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # 尝试获取用户输入
            answer = fetch_from_external_system(interrupt_id)
            if answer:
                return answer
        except Exception as e:
            logger.error(f"获取用户输入失败: {e}")
        time.sleep(0.1)
    raise TimeoutError(f"获取用户输入超时: {timeout}秒")

4. 状态管理

thread_id 管理不同用户的执行会话:

# 为每个用户创建独立的 thread_id
config = {"configurable": {"thread_id": f"user_{user_id}"}}

# 可以从特定检查点恢复
config = {"configurable": {"thread_id": f"user_{user_id}", "checkpoint_id": checkpoint_id}}
免责声明

本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。

相关阅读

更多
欢迎回来 登录或注册后,可保存提示词和历史记录
登录后可同步收藏、历史记录和常用模板
注册即表示同意服务条款与隐私政策