mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Add list files tool to enable researcher to find documents by path
Allow getting a map of user's knowledge base under specified path. This enables more thorough retrieval from user's knowledge base by combining search, view and list files tools.
This commit is contained in:
@@ -1716,6 +1716,14 @@ class FileObjectAdapters:
|
|||||||
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
async def aget_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@arequire_valid_user
|
||||||
|
async def aget_file_objects_by_path_prefix(user: KhojUser, path_prefix: str, agent: Agent = None):
|
||||||
|
"""Get file objects from the database by path prefix."""
|
||||||
|
return await sync_to_async(list)(
|
||||||
|
FileObject.objects.filter(user=user, agent=agent, file_name__startswith=path_prefix)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@arequire_valid_user
|
@arequire_valid_user
|
||||||
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
async def aget_file_objects_by_names(user: KhojUser, file_names: List[str]):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import fnmatch
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -2878,3 +2879,63 @@ async def view_file_content(
|
|||||||
|
|
||||||
# Return an error result in the expected format
|
# Return an error result in the expected format
|
||||||
yield [{"query": query, "file": path, "compiled": error_msg}]
|
yield [{"query": query, "file": path, "compiled": error_msg}]
|
||||||
|
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
path: Optional[str] = None,
|
||||||
|
pattern: Optional[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List files under a given path or glob pattern from the user's document database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Construct the query string based on provided parameters
|
||||||
|
def _generate_query(doc_count, path, pattern):
|
||||||
|
query = f"**Found {doc_count} files**"
|
||||||
|
if path:
|
||||||
|
query += f" in {path}"
|
||||||
|
if pattern:
|
||||||
|
query += f" filtered by {pattern}"
|
||||||
|
return query
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get user files by path prefix when specified
|
||||||
|
path = path or ""
|
||||||
|
if path in ["", "/", ".", "./", "~", "~/"]:
|
||||||
|
file_objects = await FileObjectAdapters.aget_all_file_objects(user, limit=10000)
|
||||||
|
else:
|
||||||
|
file_objects = await FileObjectAdapters.aget_file_objects_by_path_prefix(user, path)
|
||||||
|
|
||||||
|
if not file_objects:
|
||||||
|
yield {"query": _generate_query(0, path, pattern), "file": path, "compiled": "No files found."}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract file names from file objects
|
||||||
|
files = [f.file_name for f in file_objects]
|
||||||
|
# Convert to relative file path (similar to ls)
|
||||||
|
if path:
|
||||||
|
files = [f[len(path) :] for f in files]
|
||||||
|
|
||||||
|
# Apply glob pattern filtering if specified
|
||||||
|
if pattern:
|
||||||
|
files = [f for f in files if fnmatch.fnmatch(f, pattern)]
|
||||||
|
|
||||||
|
query = _generate_query(len(files), path, pattern)
|
||||||
|
if not files:
|
||||||
|
yield {"query": query, "file": path, "compiled": "No files found."}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Truncate the list if it's too long
|
||||||
|
max_files = 100
|
||||||
|
if len(files) > max_files:
|
||||||
|
files = files[:max_files] + [
|
||||||
|
f"... {len(files) - max_files} more files found. Use glob pattern to narrow down results."
|
||||||
|
]
|
||||||
|
|
||||||
|
yield {"query": query, "file": path, "compiled": "\n- ".join(files)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error listing files in {path}: {str(e)}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
yield {"query": query, "file": path, "compiled": error_msg}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from khoj.processor.tools.run_code import run_code
|
|||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
|
list_files,
|
||||||
search_documents,
|
search_documents,
|
||||||
send_message_to_model_wrapper,
|
send_message_to_model_wrapper,
|
||||||
view_file_content,
|
view_file_content,
|
||||||
@@ -91,7 +92,11 @@ async def apick_next_tool(
|
|||||||
if tool == ConversationCommand.Operator and not is_operator_enabled():
|
if tool == ConversationCommand.Operator and not is_operator_enabled():
|
||||||
continue
|
continue
|
||||||
# Skip showing document related tools if user has no documents
|
# Skip showing document related tools if user has no documents
|
||||||
if (tool == ConversationCommand.Notes or tool == ConversationCommand.ViewFile) and not user_has_entries:
|
if (
|
||||||
|
tool == ConversationCommand.Notes
|
||||||
|
or tool == ConversationCommand.ViewFile
|
||||||
|
or tool == ConversationCommand.ListFiles
|
||||||
|
) and not user_has_entries:
|
||||||
continue
|
continue
|
||||||
# Skip showing Notes tool as an option if user has no entries
|
# Skip showing Notes tool as an option if user has no entries
|
||||||
if tool == ConversationCommand.Notes:
|
if tool == ConversationCommand.Notes:
|
||||||
@@ -447,6 +452,25 @@ async def research(
|
|||||||
this_iteration.warning = f"Error viewing file: {e}"
|
this_iteration.warning = f"Error viewing file: {e}"
|
||||||
logger.error(this_iteration.warning, exc_info=True)
|
logger.error(this_iteration.warning, exc_info=True)
|
||||||
|
|
||||||
|
elif this_iteration.query.name == ConversationCommand.ListFiles:
|
||||||
|
try:
|
||||||
|
async for result in list_files(
|
||||||
|
**this_iteration.query.args,
|
||||||
|
user=user,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
if this_iteration.context is None:
|
||||||
|
this_iteration.context = []
|
||||||
|
document_results: List[Dict[str, str]] = [result] # type: ignore
|
||||||
|
this_iteration.context += document_results
|
||||||
|
async for result in send_status_func(result["query"]):
|
||||||
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
this_iteration.warning = f"Error listing files: {e}"
|
||||||
|
logger.error(this_iteration.warning, exc_info=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# No valid tools. This is our exit condition.
|
# No valid tools. This is our exit condition.
|
||||||
current_iteration = MAX_ITERATIONS
|
current_iteration = MAX_ITERATIONS
|
||||||
|
|||||||
@@ -430,6 +430,7 @@ class ConversationCommand(str, Enum):
|
|||||||
Research = "research"
|
Research = "research"
|
||||||
Operator = "operator"
|
Operator = "operator"
|
||||||
ViewFile = "view_file"
|
ViewFile = "view_file"
|
||||||
|
ListFiles = "list_files"
|
||||||
|
|
||||||
|
|
||||||
command_descriptions = {
|
command_descriptions = {
|
||||||
@@ -444,6 +445,7 @@ command_descriptions = {
|
|||||||
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
||||||
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
|
ConversationCommand.Operator: "Operate and perform tasks using a computer.",
|
||||||
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
|
ConversationCommand.ViewFile: "View the contents of a file with optional line range specification.",
|
||||||
|
ConversationCommand.ListFiles: "List files under a given path with optional glob pattern.",
|
||||||
}
|
}
|
||||||
|
|
||||||
command_descriptions_for_agent = {
|
command_descriptions_for_agent = {
|
||||||
@@ -576,6 +578,23 @@ tools_for_research_llm = {
|
|||||||
"required": ["path"],
|
"required": ["path"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
ConversationCommand.ListFiles: ToolDefinition(
|
||||||
|
name="list_files",
|
||||||
|
description="To list files under a given path or glob pattern.",
|
||||||
|
schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The directory path to list files from.",
|
||||||
|
},
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional glob pattern to filter files (e.g., '*.md').",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
mode_descriptions_for_llm = {
|
mode_descriptions_for_llm = {
|
||||||
|
|||||||
Reference in New Issue
Block a user