Support using MCP tools in research mode

- Server admin can add MCP servers via the admin panel
- Enabled MCP server tools are exposed to the research agent for use
- Use MCP library to standardize interactions with mcp servers
  - Support SSE or Stdio as transport to interact with mcp servers
  - Reuse session established to MCP servers across research iterations
This commit is contained in:
Debanjum
2025-11-13 12:54:16 -08:00
parent 2ac7359092
commit 3496189618
8 changed files with 402 additions and 2 deletions

View File

@@ -49,6 +49,7 @@ from khoj.database.models import (
GoogleUser,
KhojApiUser,
KhojUser,
McpServer,
NotionConfig,
PriceTier,
ProcessLock,
@@ -2127,3 +2128,15 @@ class AutomationAdapters:
automation.remove()
return automation_metadata
class McpServerAdapters:
@staticmethod
async def aget_all_mcp_servers() -> List[McpServer]:
"""Asynchronously retrieve all McpServer objects from the database."""
servers: List[McpServer] = []
try:
servers = [server async for server in McpServer.objects.all()]
except Exception as e:
logger.error(f"Error retrieving MCP servers: {e}", exc_info=True)
return servers

View File

@@ -23,6 +23,7 @@ from khoj.database.models import (
Entry,
GithubConfig,
KhojUser,
McpServer,
NotionConfig,
ProcessLock,
RateLimitRecord,
@@ -183,6 +184,16 @@ admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
@admin.register(McpServer)
class McpServerAdmin(unfold_admin.ModelAdmin):
list_display = (
"id",
"name",
"path",
)
search_fields = ("id", "name", "path")
@admin.register(Agent)
class AgentAdmin(unfold_admin.ModelAdmin):
list_display = (

View File

@@ -0,0 +1,26 @@
# Generated by Django 5.1.14 on 2025-11-13 20:58
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0095_alter_webscraper_type"),
]
operations = [
migrations.CreateModel(
name="McpServer",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=50, unique=True)),
("path", models.CharField(max_length=200, unique=True)),
("api_key", models.CharField(blank=True, max_length=4000, null=True)),
],
options={
"abstract": False,
},
),
]

View File

@@ -806,3 +806,12 @@ class DataStore(DbBaseModel):
value = models.JSONField(default=dict)
private = models.BooleanField(default=False)
owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
class McpServer(DbBaseModel):
name = models.CharField(max_length=50, unique=True)
path = models.CharField(max_length=200, unique=True)
api_key = models.CharField(max_length=4000, blank=True, null=True)
def __str__(self):
return self.name

View File

@@ -0,0 +1,124 @@
import logging
from contextlib import AsyncExitStack
from typing import Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import AudioContent, ImageContent, TextContent
logger = logging.getLogger(__name__)
class MCPClient:
"""
A client for interacting with MCP servers.
Establishes a session with the server and provides methods to get and run tools.
Supports both stdio and sse communication methods.
"""
def __init__(self, name: str, path: str, api_key: Optional[str] = None):
self.name = name
self.path = path
self.api_key = api_key
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
async def connect(self):
"""
Connect to the MCP server based on the specified path.
Uses stdio for local scripts and SSE for remote servers.
"""
if self.path.startswith("http://") or self.path.startswith("https://"):
# Path is a URL, so use SSE to connect to the server
await self._connect_to_sse_server()
else:
# Path is a script, so use stdio to connect to the server
await self._connect_to_stdio_server()
async def get_tools(self):
"""
Retrieve the list of tools available on the MCP server.
"""
# Ensure connection is established
if not self.session:
await self.connect()
tools_response = await self.session.list_tools()
return tools_response.tools
async def run_tool(self, tool_name: str, input_data: dict):
"""
Run a specified tool on the MCP server with the given input data.
"""
# Ensure connection is established
if not self.session:
await self.connect()
result = await self.session.call_tool(tool_name, input_data)
# Process result content based on its type
if len(result.content) > 0 and isinstance(result.content[0], TextContent):
return [item.text for item in result.content]
elif len(result.content) > 0 and isinstance(result.content[0], AudioContent):
return [{"data": item.data, "format": item.mimeType} for item in result.content]
elif len(result.content) > 0 and isinstance(result.content[0], ImageContent):
return [{"data": item.data, "format": item.mimeType} for item in result.content]
return result.content
async def _connect_to_sse_server(self):
"""
Connect to the MCP server using Server-Sent Events (SSE).
"""
self._streams_context = sse_client(url=self.path)
streams = await self._streams_context.__aenter__()
self._session_context = ClientSession(*streams)
self.session = await self._session_context.__aenter__()
# Initialize
await self.session.initialize()
async def _connect_to_stdio_server(self):
"""
Connect to the MCP server using stdio communication.
"""
is_python = False
is_javascript = False
command = None
args = [self.path]
# Determine if the server is a file path or npm package
if self.path.startswith("@") or "/" not in self.path:
# Assume it's an npm package
is_javascript = True
command = "npx"
else:
# It's a file path
is_python = self.path.endswith(".py")
is_javascript = self.path.endswith(".js")
if not (is_python or is_javascript):
raise ValueError("Server script must be a .py, .js file or npm package.")
command = "python" if is_python else "node"
server_params = StdioServerParameters(command=command, args=args, env=None)
logger.debug(f"Connecting to stdio MCP server with command: {command} and args: {args}")
# Start the server
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.writer = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.writer))
await self.session.initialize()
async def close(self):
"""
Close the MCP client session and clean up resources.
"""
await self.exit_stack.aclose()
if hasattr(self, "_session_context") and self._session_context:
await self._session_context.__aexit__(None, None, None)
if hasattr(self, "_streams_context") and self._streams_context:
await self._streams_context.__aexit__(None, None, None)

View File

@@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional
import yaml
from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters
from khoj.database.models import Agent, ChatMessageModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
load_complex_json,
)
from khoj.processor.operator import operate_environment
from khoj.processor.tools.mcp import MCPClient
from khoj.processor.tools.online_search import read_webpages_content, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
@@ -61,6 +62,7 @@ async def apick_next_tool(
max_document_searches: int = 7,
max_online_searches: int = 3,
max_webpages_to_read: int = 3,
mcp_clients: List[MCPClient] = [],
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
@@ -144,6 +146,23 @@ async def apick_next_tool(
)
)
# Get MCP tools
for mcp_client in mcp_clients:
try:
mcp_tools = await mcp_client.get_tools()
for mcp_tool in mcp_tools:
qualified_tool_name = f"{mcp_client.name}/{mcp_tool.name}"
tool_options_str += f'- "{qualified_tool_name}": "{mcp_tool.description}"\n'
tools.append(
ToolDefinition(
name=qualified_tool_name,
description=mcp_tool.description,
schema=mcp_tool.inputSchema,
)
)
except Exception as e:
logger.warning(f'Failed to get tools from MCP server "{mcp_client.name}", so skipping: {e}.')
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
@@ -240,6 +259,10 @@ async def research(
current_iteration = 0
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
# Construct MCP clients
mcp_servers = await McpServerAdapters.aget_all_mcp_servers()
mcp_clients = [MCPClient(server.name, server.path, server.api_key) for server in mcp_servers]
# Incorporate previous partial research into current research chat history
research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message]
if current_iteration := len(previous_iterations) > 0:
@@ -277,6 +300,7 @@ async def research(
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
operator_results: OperatorRun = None
mcp_results: List = []
this_iteration = ResearchIteration(query=query)
async for result in apick_next_tool(
@@ -293,6 +317,7 @@ async def research(
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
max_webpages_to_read=max_webpages_to_read,
mcp_clients=mcp_clients,
send_status_func=send_status_func,
tracer=tracer,
):
@@ -540,13 +565,41 @@ async def research(
this_iteration.warning = f"Error searching with regex: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif "/" in this_iteration.query.name:
try:
# Identify MCP client to use
server_name, tool_name = this_iteration.query.name.split("/", 1)
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
if not mcp_client:
raise ValueError(f"Could not find MCP server with name {server_name}")
# Invoke tool on the identified MCP server
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
# Record tool result in context
if this_iteration.context is None:
this_iteration.context = []
this_iteration.context += mcp_results
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
yield result
except Exception as e:
this_iteration.warning = f"Error using MCP tool: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
current_iteration += 1
if document_results or online_results or code_results or operator_results or this_iteration.warning:
if (
document_results
or online_results
or code_results
or operator_results
or mcp_results
or this_iteration.warning
):
results_data = f"\n<iteration_{current_iteration}_results>"
if document_results:
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
@@ -558,6 +611,8 @@ async def research(
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if mcp_results:
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += f"\n</iteration_{current_iteration}_results>"
@@ -568,3 +623,7 @@ async def research(
this_iteration.summarizedResult = this_iteration.summarizedResult or "Failed to get results."
previous_iterations.append(this_iteration)
yield this_iteration
# Close MCP client connections
for mcp_client in mcp_clients:
await mcp_client.close()