Add list files tool to enable researcher to find documents by path

Allow getting a map of the user's knowledge base under a specified path.

This enables more thorough retrieval from the user's knowledge base by
combining the search, view and list files tools.
This commit is contained in:
Debanjum
2025-06-14 00:34:47 -07:00
parent 2f9f608cff
commit 59f5648dbd
4 changed files with 113 additions and 1 deletions

View File

@@ -1716,6 +1716,14 @@ class FileObjectAdapters:
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod
@arequire_valid_user
async def aget_file_objects_by_path_prefix(user: KhojUser, path_prefix: str, agent: Agent = None):
    """Return the file objects of a user (and agent) whose file name starts with the path prefix."""
    # Build the lazy queryset here; it is only evaluated when list() runs via sync_to_async
    prefixed_files = FileObject.objects.filter(
        user=user,
        agent=agent,
        file_name__startswith=path_prefix,
    )
    fetch_as_list = sync_to_async(list)
    return await fetch_as_list(prefixed_files)
@staticmethod
@arequire_valid_user
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):

View File

@@ -1,6 +1,7 @@
import asyncio
import base64
import concurrent.futures
import fnmatch
import hashlib
import json
import logging
@@ -2878,3 +2879,63 @@ async def view_file_content(
# Return an error result in the expected format
yield [{"query": query, "file": path, "compiled": error_msg}]
async def list_files(
    path: Optional[str] = None,
    pattern: Optional[str] = None,
    user: KhojUser = None,
):
    """
    List files under a given path or glob pattern from the user's document database.

    Args:
        path: File name prefix to list files under. An empty value or a root
            alias ("/", ".", "./", "~", "~/") lists all files of the user
            (requested with limit=10000).
        pattern: Optional glob pattern, matched with fnmatch against the listed
            file names, to filter the results.
        user: The user whose files should be listed.

    Yields:
        A single dict with the keys "query" (a summary of the number of files found),
        "file" (the requested path) and "compiled" (the file listing, a
        "No files found." notice or an error message).
        Errors are logged and reported via the "compiled" field instead of being raised.
    """
    # Paths that refer to the root of the knowledge base rather than to a file name prefix
    root_paths = ["", "/", ".", "./", "~", "~/"]

    # Construct the query string based on provided parameters
    def _generate_query(doc_count, path, pattern):
        query = f"**Found {doc_count} files**"
        if path:
            query += f" in {path}"
        if pattern:
            query += f" filtered by {pattern}"
        return query

    path = path or ""
    is_root_path = path in root_paths
    # Define the query before the try block so the error handler below can always
    # reference it, even when the failure happens before any files were counted.
    query = _generate_query(0, path, pattern)

    try:
        # Get user files by path prefix when specified
        if is_root_path:
            file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
        else:
            file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)

        if not file_objects:
            yield {"query": query, "file": path, "compiled": "No files found."}
            return

        # Extract file names from file objects
        files = [f.file_name for f in file_objects]

        # Convert to relative file path (similar to ls).
        # Only strip when the files were actually selected by this prefix. Root aliases
        # are not a prefix of the stored file names, so slicing would corrupt them.
        if not is_root_path:
            files = [f[len(path) :] for f in files]

        # Apply glob pattern filtering if specified
        if pattern:
            files = [f for f in files if fnmatch.fnmatch(f, pattern)]

        query = _generate_query(len(files), path, pattern)
        if not files:
            yield {"query": query, "file": path, "compiled": "No files found."}
            return

        # Truncate the list if it's too long
        max_files = 100
        if len(files) > max_files:
            files = files[:max_files] + [
                f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
            ]

        yield {"query": query, "file": path, "compiled": "\n- ".join(files)}

    except Exception as e:
        error_msg = f"Error listing files in {path}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        yield {"query": query, "file": path, "compiled": error_msg}

View File

@@ -24,6 +24,7 @@ from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
list_files,
search_documents,
send_message_to_model_wrapper,
view_file_content,
@@ -91,7 +92,11 @@ async def apick_next_tool(
if tool == ConversationCommand.Operator and not is_operator_enabled():
continue
# Skip showing document related tools if user has no documents
if (tool == ConversationCommand.Notes or tool == ConversationCommand.ViewFile) and not user_has_entries:
if (
tool == ConversationCommand.Notes
or tool == ConversationCommand.ViewFile
or tool == ConversationCommand.ListFiles
) and not user_has_entries:
continue
# Skip showing Notes tool as an option if user has no entries
if tool == ConversationCommand.Notes:
@@ -447,6 +452,25 @@ async def research(
this_iteration.warning = f"Error viewing file: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ListFiles:
try:
async for result in list_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error listing files: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS

View File

@@ -430,6 +430,7 @@ class ConversationCommand(str, Enum):
Research = "research"
Operator = "operator"
ViewFile = "view_file"
ListFiles = "list_files"
command_descriptions = {
@@ -444,6 +445,7 @@ command_descriptions = {
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
ConversationCommand.ListFiles: "List files under a given path with optional glob pattern.",
}
command_descriptions_for_agent = {
@@ -576,6 +578,23 @@ tools_for_research_llm = {
"required": ["path"],
},
),
ConversationCommand.ListFiles: ToolDefinition(
name="list_files",
description="To list files under a given path or glob pattern.",
schema={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The directory path to list files from.",
},
"pattern": {
"type": "string",
"description": "Optional glob pattern to filter files (e.g., '*.md').",
},
},
},
),
}
mode_descriptions_for_llm = {