mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Support using MCP tools in research mode
- Server admins can add MCP servers via the admin panel
- Tools from enabled MCP servers are exposed to the research agent for use
- Use the MCP library to standardize interactions with MCP servers
- Support SSE or stdio as the transport for interacting with MCP servers
- Reuse the session established with each MCP server across research iterations
This commit is contained in:
@@ -49,6 +49,7 @@ from khoj.database.models import (
|
||||
GoogleUser,
|
||||
KhojApiUser,
|
||||
KhojUser,
|
||||
McpServer,
|
||||
NotionConfig,
|
||||
PriceTier,
|
||||
ProcessLock,
|
||||
@@ -2127,3 +2128,15 @@ class AutomationAdapters:
|
||||
|
||||
automation.remove()
|
||||
return automation_metadata
|
||||
|
||||
|
||||
class McpServerAdapters:
    """Database access helpers for McpServer records."""

    @staticmethod
    async def aget_all_mcp_servers() -> List[McpServer]:
        """Fetch every McpServer row asynchronously.

        Best-effort: any failure while querying is logged with its traceback
        and an empty list is returned instead of propagating the error.
        """
        try:
            return [server async for server in McpServer.objects.all()]
        except Exception as e:
            logger.error(f"Error retrieving MCP servers: {e}", exc_info=True)
            return []
|
||||
|
||||
@@ -23,6 +23,7 @@ from khoj.database.models import (
|
||||
Entry,
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
McpServer,
|
||||
NotionConfig,
|
||||
ProcessLock,
|
||||
RateLimitRecord,
|
||||
@@ -183,6 +184,16 @@ admin.site.register(UserRequests, unfold_admin.ModelAdmin)
|
||||
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
|
||||
|
||||
|
||||
@admin.register(McpServer)
class McpServerAdmin(unfold_admin.ModelAdmin):
    """Admin panel configuration for McpServer entries."""

    # Columns shown in the admin changelist.
    list_display = ("id", "name", "path")
    # Fields matched by the admin search box.
    search_fields = ("id", "name", "path")
|
||||
|
||||
|
||||
@admin.register(Agent)
|
||||
class AgentAdmin(unfold_admin.ModelAdmin):
|
||||
list_display = (
|
||||
|
||||
26
src/khoj/database/migrations/0096_mcpserver.py
Normal file
26
src/khoj/database/migrations/0096_mcpserver.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Generated by Django 5.1.14 on 2025-11-13 20:58
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Create the McpServer model (name, path and optional api_key)."""

    # Must be applied after the previous migration of the "database" app.
    dependencies = [
        ("database", "0095_alter_webscraper_type"),
    ]

    operations = [
        migrations.CreateModel(
            name="McpServer",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                # Timestamps: set once on insert / refreshed on every save.
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # Both name and path are unique across servers.
                ("name", models.CharField(max_length=50, unique=True)),
                ("path", models.CharField(max_length=200, unique=True)),
                # Optional credential; may be left blank or NULL.
                ("api_key", models.CharField(blank=True, max_length=4000, null=True)),
            ],
            options={
                "abstract": False,
            },
        ),
    ]
|
||||
@@ -806,3 +806,12 @@ class DataStore(DbBaseModel):
|
||||
value = models.JSONField(default=dict)
|
||||
private = models.BooleanField(default=False)
|
||||
owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class McpServer(DbBaseModel):
    """An MCP server registered for use by the research agent."""

    # Unique name of the server (at most 50 characters); also returned by __str__.
    name = models.CharField(max_length=50, unique=True)
    # Unique location of the server (at most 200 characters). MCPClient treats
    # values starting with http:// or https:// as SSE endpoints and anything
    # else as a script path or npm package to launch over stdio.
    path = models.CharField(max_length=200, unique=True)
    # Optional credential for the server; may be blank or NULL.
    api_key = models.CharField(max_length=4000, blank=True, null=True)

    def __str__(self) -> str:
        """Return the server name as its human-readable representation."""
        return self.name
|
||||
|
||||
124
src/khoj/processor/tools/mcp.py
Normal file
124
src/khoj/processor/tools/mcp.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import logging
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import AudioContent, ImageContent, TextContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
    """
    A client for interacting with MCP servers.
    Establishes a session with the server and provides methods to get and run tools.
    Supports both stdio and sse communication methods.

    The session is opened lazily on first use and reused until `close()` is called.
    Every context opened while connecting is owned by `exit_stack`, so a single
    `close()` call releases the session and its transport in reverse order.
    """

    def __init__(self, name: str, path: str, api_key: Optional[str] = None):
        """
        Args:
            name: Name used to identify this server.
            path: http(s) URL of an SSE server, or the script path / npm package
                of a server to launch over stdio.
            api_key: Optional credential. Sent as a bearer token to SSE servers.
                It is not forwarded to stdio servers.
        """
        self.name = name
        self.path = path
        self.api_key = api_key
        # Only set once a session has been fully initialized, so a failed
        # connection attempt is retried on the next call instead of being reused.
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def connect(self):
        """
        Connect to the MCP server based on the specified path.
        Uses stdio for local scripts and SSE for remote servers.
        """
        if self.path.startswith(("http://", "https://")):
            # Path is a URL, so use SSE to connect to the server
            await self._connect_to_sse_server()
        else:
            # Path is a script, so use stdio to connect to the server
            await self._connect_to_stdio_server()

    async def get_tools(self):
        """
        Retrieve the list of tools available on the MCP server.
        """
        # Ensure connection is established
        if not self.session:
            await self.connect()

        tools_response = await self.session.list_tools()
        return tools_response.tools

    async def run_tool(self, tool_name: str, input_data: dict):
        """
        Run a specified tool on the MCP server with the given input data.

        Returns:
            A list with one entry per content item in the tool result: the text
            for text content, a {"data", "format"} dict for audio and image
            content, and the unmodified item for any other content type.
        """
        # Ensure connection is established
        if not self.session:
            await self.connect()

        result = await self.session.call_tool(tool_name, input_data)

        # Convert each item by its own type. Deciding from the first item alone
        # breaks on results that mix text with audio, images or other content.
        return [self._format_content(item) for item in result.content]

    @staticmethod
    def _format_content(item):
        """Convert a single tool result content item into a plain value."""
        if isinstance(item, TextContent):
            return item.text
        if isinstance(item, (AudioContent, ImageContent)):
            return {"data": item.data, "format": item.mimeType}
        return item

    def _auth_headers(self) -> Optional[dict]:
        """Return the HTTP headers carrying the api key, or None when no key is set."""
        if not self.api_key:
            return None
        # NOTE(review): assumes the server expects the key as a bearer token - confirm.
        return {"Authorization": f"Bearer {self.api_key}"}

    def _stdio_command(self) -> str:
        """
        Pick the executable used to launch a stdio server from the path.

        Raises:
            ValueError: If the path is a file that is neither a .py nor a .js script.
        """
        # Determine if the server is a file path or npm package
        if self.path.startswith("@") or "/" not in self.path:
            # Assume it's an npm package
            return "npx"
        # It's a file path
        if self.path.endswith(".py"):
            return "python"
        if self.path.endswith(".js"):
            return "node"
        raise ValueError("Server script must be a .py, .js file or npm package.")

    async def _open_session(self, transport_context):
        """
        Open the transport, start a client session over it and initialize the session.

        Returns the (read, write) streams of the transport. If any step fails,
        everything opened so far is released and the error is re-raised.
        """
        try:
            read_stream, write_stream = await self.exit_stack.enter_async_context(transport_context)
            session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
            await session.initialize()
        except Exception:
            # Do not leave a half-open transport behind for the next attempt
            await self.exit_stack.aclose()
            raise

        self.session = session
        return read_stream, write_stream

    async def _connect_to_sse_server(self):
        """
        Connect to the MCP server using Server-Sent Events (SSE).
        """
        await self._open_session(sse_client(url=self.path, headers=self._auth_headers()))

    async def _connect_to_stdio_server(self):
        """
        Connect to the MCP server using stdio communication.
        """
        command = self._stdio_command()
        args = [self.path]
        server_params = StdioServerParameters(command=command, args=args, env=None)

        logger.debug(f"Connecting to stdio MCP server with command: {command} and args: {args}")

        # Start the server
        self.stdio, self.writer = await self._open_session(stdio_client(server_params))

    async def close(self):
        """
        Close the MCP client session and clean up resources.
        """
        try:
            await self.exit_stack.aclose()
        finally:
            # Forget the closed session so a later call reconnects instead of reusing it
            self.session = None
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.adapters import AgentAdapters, EntryAdapters, McpServerAdapters
|
||||
from khoj.database.models import Agent, ChatMessageModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
@@ -19,6 +19,7 @@ from khoj.processor.conversation.utils import (
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.tools.mcp import MCPClient
|
||||
from khoj.processor.tools.online_search import read_webpages_content, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
@@ -61,6 +62,7 @@ async def apick_next_tool(
|
||||
max_document_searches: int = 7,
|
||||
max_online_searches: int = 3,
|
||||
max_webpages_to_read: int = 3,
|
||||
mcp_clients: List[MCPClient] = [],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
@@ -144,6 +146,23 @@ async def apick_next_tool(
|
||||
)
|
||||
)
|
||||
|
||||
# Get MCP tools
|
||||
for mcp_client in mcp_clients:
|
||||
try:
|
||||
mcp_tools = await mcp_client.get_tools()
|
||||
for mcp_tool in mcp_tools:
|
||||
qualified_tool_name = f"{mcp_client.name}/{mcp_tool.name}"
|
||||
tool_options_str += f'- "{qualified_tool_name}": "{mcp_tool.description}"\n'
|
||||
tools.append(
|
||||
ToolDefinition(
|
||||
name=qualified_tool_name,
|
||||
description=mcp_tool.description,
|
||||
schema=mcp_tool.inputSchema,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to get tools from MCP server "{mcp_client.name}", so skipping: {e}.')
|
||||
|
||||
today = datetime.today()
|
||||
location_data = f"{location}" if location else "Unknown"
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
@@ -240,6 +259,10 @@ async def research(
|
||||
current_iteration = 0
|
||||
MAX_ITERATIONS = int(os.getenv("KHOJ_RESEARCH_ITERATIONS", 5))
|
||||
|
||||
# Construct MCP clients
|
||||
mcp_servers = await McpServerAdapters.aget_all_mcp_servers()
|
||||
mcp_clients = [MCPClient(server.name, server.path, server.api_key) for server in mcp_servers]
|
||||
|
||||
# Incorporate previous partial research into current research chat history
|
||||
research_conversation_history = [chat for chat in deepcopy(conversation_history) if chat.message]
|
||||
if current_iteration := len(previous_iterations) > 0:
|
||||
@@ -277,6 +300,7 @@ async def research(
|
||||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
operator_results: OperatorRun = None
|
||||
mcp_results: List = []
|
||||
this_iteration = ResearchIteration(query=query)
|
||||
|
||||
async for result in apick_next_tool(
|
||||
@@ -293,6 +317,7 @@ async def research(
|
||||
max_document_searches=max_document_searches,
|
||||
max_online_searches=max_online_searches,
|
||||
max_webpages_to_read=max_webpages_to_read,
|
||||
mcp_clients=mcp_clients,
|
||||
send_status_func=send_status_func,
|
||||
tracer=tracer,
|
||||
):
|
||||
@@ -540,13 +565,41 @@ async def research(
|
||||
this_iteration.warning = f"Error searching with regex: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif "/" in this_iteration.query.name:
|
||||
try:
|
||||
# Identify MCP client to use
|
||||
server_name, tool_name = this_iteration.query.name.split("/", 1)
|
||||
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
|
||||
if not mcp_client:
|
||||
raise ValueError(f"Could not find MCP server with name {server_name}")
|
||||
|
||||
# Invoke tool on the identified MCP server
|
||||
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
|
||||
|
||||
# Record tool result in context
|
||||
if this_iteration.context is None:
|
||||
this_iteration.context = []
|
||||
this_iteration.context += mcp_results
|
||||
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
|
||||
yield result
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error using MCP tool: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
else:
|
||||
# No valid tools. This is our exit condition.
|
||||
current_iteration = MAX_ITERATIONS
|
||||
|
||||
current_iteration += 1
|
||||
|
||||
if document_results or online_results or code_results or operator_results or this_iteration.warning:
|
||||
if (
|
||||
document_results
|
||||
or online_results
|
||||
or code_results
|
||||
or operator_results
|
||||
or mcp_results
|
||||
or this_iteration.warning
|
||||
):
|
||||
results_data = f"\n<iteration_{current_iteration}_results>"
|
||||
if document_results:
|
||||
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
|
||||
@@ -558,6 +611,8 @@ async def research(
|
||||
results_data += (
|
||||
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
|
||||
)
|
||||
if mcp_results:
|
||||
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
|
||||
if this_iteration.warning:
|
||||
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
|
||||
results_data += f"\n</iteration_{current_iteration}_results>"
|
||||
@@ -568,3 +623,7 @@ async def research(
|
||||
this_iteration.summarizedResult = this_iteration.summarizedResult or "Failed to get results."
|
||||
previous_iterations.append(this_iteration)
|
||||
yield this_iteration
|
||||
|
||||
# Close MCP client connections
|
||||
for mcp_client in mcp_clients:
|
||||
await mcp_client.close()
|
||||
|
||||
Reference in New Issue
Block a user