mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add list files tool to enable researcher to find documents by path
Allow getting a map of user's knowledge base under specified path. This enables more thorough retrieval from user's knowledge base by combining search, view and list files tools.
This commit is contained in:
@@ -1716,6 +1716,14 @@ class FileObjectAdapters:
|
||||
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_path_prefix(user: KhojUser, path_prefix: str, agent: Agent = None):
|
||||
"""Get file objects from the database by path prefix."""
|
||||
return await sync_to_async(list)(
|
||||
FileObject.objects.filter(user=user, agent=agent, file_name__startswith=path_prefix)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -2878,3 +2879,63 @@ async def view_file_content(
|
||||
|
||||
# Return an error result in the expected format
|
||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||
|
||||
|
||||
async def list_files(
|
||||
path: Optional[str] = None,
|
||||
pattern: Optional[str] = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
List files under a given path or glob pattern from the user's document database.
|
||||
"""
|
||||
|
||||
# Construct the query string based on provided parameters
|
||||
def _generate_query(doc_count, path, pattern):
|
||||
query = f"**Found {doc_count} files**"
|
||||
if path:
|
||||
query += f" in {path}"
|
||||
if pattern:
|
||||
query += f" filtered by {pattern}"
|
||||
return query
|
||||
|
||||
try:
|
||||
# Get user files by path prefix when specified
|
||||
path = path or ""
|
||||
if path in ["", "/", ".", "./", "~", "~/"]:
|
||||
file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
|
||||
else:
|
||||
file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)
|
||||
|
||||
if not file_objects:
|
||||
yield {"query": _generate_query(0, path, pattern), "file": path, "compiled": "No files found."}
|
||||
return
|
||||
|
||||
# Extract file names from file objects
|
||||
files = [f.file_name for f in file_objects]
|
||||
# Convert to relative file path (similar to ls)
|
||||
if path:
|
||||
files = [f[len(path) :] for f in files]
|
||||
|
||||
# Apply glob pattern filtering if specified
|
||||
if pattern:
|
||||
files = [f for f in files if fnmatch.fnmatch(f, pattern)]
|
||||
|
||||
query = _generate_query(len(files), path, pattern)
|
||||
if not files:
|
||||
yield {"query": query, "file": path, "compiled": "No files found."}
|
||||
return
|
||||
|
||||
# Truncate the list if it's too long
|
||||
max_files = 100
|
||||
if len(files) > max_files:
|
||||
files = files[:max_files] + [
|
||||
f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
|
||||
]
|
||||
|
||||
yield {"query": query, "file": path, "compiled": "\n- ".join(files)}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error listing files in {path}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield {"query": query, "file": path, "compiled": error_msg}
|
||||
|
||||
@@ -24,6 +24,7 @@ from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
generate_summary_from_files,
|
||||
list_files,
|
||||
search_documents,
|
||||
send_message_to_model_wrapper,
|
||||
view_file_content,
|
||||
@@ -91,7 +92,11 @@ async def apick_next_tool(
|
||||
if tool == ConversationCommand.Operator and not is_operator_enabled():
|
||||
continue
|
||||
# Skip showing document related tools if user has no documents
|
||||
if (tool == ConversationCommand.Notes or tool == ConversationCommand.ViewFile) and not user_has_entries:
|
||||
if (
|
||||
tool == ConversationCommand.Notes
|
||||
or tool == ConversationCommand.ViewFile
|
||||
or tool == ConversationCommand.ListFiles
|
||||
) and not user_has_entries:
|
||||
continue
|
||||
# Skip showing Notes tool as an option if user has no entries
|
||||
if tool == ConversationCommand.Notes:
|
||||
@@ -447,6 +452,25 @@ async def research(
|
||||
this_iteration.warning = f"Error viewing file: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
elif this_iteration.query.name == ConversationCommand.ListFiles:
|
||||
try:
|
||||
async for result in list_files(
|
||||
**this_iteration.query.args,
|
||||
user=user,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
if this_iteration.context is None:
|
||||
this_iteration.context = []
|
||||
document_results: List[Dict[str, str]] = [result] # type: ignore
|
||||
this_iteration.context += document_results
|
||||
async for result in send_status_func(result["query"]):
|
||||
yield result
|
||||
except Exception as e:
|
||||
this_iteration.warning = f"Error listing files: {e}"
|
||||
logger.error(this_iteration.warning, exc_info=True)
|
||||
|
||||
else:
|
||||
# No valid tools. This is our exit condition.
|
||||
current_iteration = MAX_ITERATIONS
|
||||
|
||||
@@ -430,6 +430,7 @@ class ConversationCommand(str, Enum):
|
||||
Research = "research"
|
||||
Operator = "operator"
|
||||
ViewFile = "view_file"
|
||||
ListFiles = "list_files"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
@@ -444,6 +445,7 @@ command_descriptions = {
|
||||
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
||||
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
|
||||
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
|
||||
ConversationCommand.ListFiles: "List files under a given path with optional glob pattern.",
|
||||
}
|
||||
|
||||
command_descriptions_for_agent = {
|
||||
@@ -576,6 +578,23 @@ tools_for_research_llm = {
|
||||
"required": ["path"],
|
||||
},
|
||||
),
|
||||
ConversationCommand.ListFiles: ToolDefinition(
|
||||
name="list_files",
|
||||
description="To list files under a given path or glob pattern.",
|
||||
schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory path to list files from.",
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional glob pattern to filter files (e.g., '*.md').",
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
|
||||
Reference in New Issue
Block a user