Format server code with ruff recommendations

This commit is contained in:
Debanjum
2025-08-01 00:10:34 -07:00
parent 4a3ed9e5a4
commit c8e07e86e4
65 changed files with 407 additions and 370 deletions

View File

@@ -14,6 +14,7 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path 1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
""" """
from django.contrib import admin from django.contrib import admin
from django.contrib.staticfiles.urls import staticfiles_urlpatterns from django.contrib.staticfiles.urls import staticfiles_urlpatterns
from django.urls import path from django.urls import path

View File

@@ -1910,9 +1910,9 @@ class EntryAdapters:
owner_filter = Q() owner_filter = Q()
if user != None: if user is not None:
owner_filter = Q(user=user) owner_filter = Q(user=user)
if agent != None: if agent is not None:
owner_filter |= Q(agent=agent) owner_filter |= Q(agent=agent)
if owner_filter == Q(): if owner_filter == Q():
@@ -1972,9 +1972,9 @@ class EntryAdapters:
): ):
owner_filter = Q() owner_filter = Q()
if user != None: if user is not None:
owner_filter = Q(user=user) owner_filter = Q(user=user)
if agent != None: if agent is not None:
owner_filter |= Q(agent=agent) owner_filter |= Q(agent=agent)
if owner_filter == Q(): if owner_filter == Q():

View File

@@ -1,5 +1,4 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Exists, OuterRef from django.db.models import Exists, OuterRef
from khoj.database.models import Entry, FileObject from khoj.database.models import Entry, FileObject

View File

@@ -41,7 +41,7 @@ def update_conversation_id_in_job_state(apps, schema_editor):
job.save() job.save()
except Conversation.DoesNotExist: except Conversation.DoesNotExist:
pass pass
except LookupError as e: except LookupError:
pass pass

View File

@@ -1,6 +1,6 @@
# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59 # Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
from django.db import migrations, models from django.db import migrations
# This script was written alongside when Pydantic validation was added to the Conversation conversation_log field. # This script was written alongside when Pydantic validation was added to the Conversation conversation_log field.

View File

@@ -551,12 +551,12 @@ class TextToImageModelConfig(DbBaseModel):
error = {} error = {}
if self.model_type == self.ModelType.OPENAI: if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.ai_model_api: if self.api_key and self.ai_model_api:
error[ error["api_key"] = (
"api_key" "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them."
] = "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them." )
error[ error["ai_model_api"] = (
"ai_model_api" "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." )
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE: if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
if not self.api_key: if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI, non Google models." error["api_key"] = "The API key field must be set for non OpenAI, non Google models."

View File

@@ -1,3 +1 @@
from django.test import TestCase
# Create your tests here. # Create your tests here.

View File

@@ -1,5 +1,5 @@
""" Main module for Khoj """Main module for Khoj
isort:skip_file isort:skip_file
""" """
from contextlib import redirect_stdout from contextlib import redirect_stdout
@@ -189,7 +189,7 @@ def run(should_start_server=True):
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static") static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
if not os.path.exists(static_dir): if not os.path.exists(static_dir):
os.mkdir(static_dir) os.mkdir(static_dir)
app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir) app.mount("/static", StaticFiles(directory=static_dir), name=static_dir)
# Configure Middleware # Configure Middleware
configure_middleware(app, state.ssl_config) configure_middleware(app, state.ssl_config)

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
"""Django's command-line utility for administrative tasks.""" """Django's command-line utility for administrative tasks."""
import os import os
import sys import sys

View File

@@ -51,7 +51,7 @@ class GithubToEntries(TextToEntries):
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if is_none_or_empty(self.config.pat_token): if is_none_or_empty(self.config.pat_token):
logger.warning( logger.warning(
f"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply." "Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply."
) )
current_entries = [] current_entries = []
for repo in self.config.repos: for repo in self.config.repos:
@@ -137,7 +137,7 @@ class GithubToEntries(TextToEntries):
# Find all markdown files in the repository # Find all markdown files in the repository
if item["type"] == "blob" and item["path"].endswith(".md"): if item["type"] == "blob" and item["path"].endswith(".md"):
# Create URL for each markdown file on Github # Create URL for each markdown file on Github
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
# Add markdown file contents and URL to list # Add markdown file contents and URL to list
markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}] markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
@@ -145,19 +145,19 @@ class GithubToEntries(TextToEntries):
# Find all org files in the repository # Find all org files in the repository
elif item["type"] == "blob" and item["path"].endswith(".org"): elif item["type"] == "blob" and item["path"].endswith(".org"):
# Create URL for each org file on Github # Create URL for each org file on Github
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
# Add org file contents and URL to list # Add org file contents and URL to list
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}] org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
# Find, index remaining non-binary files in the repository # Find, index remaining non-binary files in the repository
elif item["type"] == "blob": elif item["type"] == "blob":
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
content_bytes = self.get_file_contents(item["url"], decode=False) content_bytes = self.get_file_contents(item["url"], decode=False)
content_type, content_str = None, None content_type, content_str = None, None
try: try:
content_type = magika.identify_bytes(content_bytes).output.group content_type = magika.identify_bytes(content_bytes).output.group
except: except Exception:
logger.error(f"Unable to identify content type of file at {url_path}. Skip indexing it") logger.error(f"Unable to identify content type of file at {url_path}. Skip indexing it")
continue continue
@@ -165,7 +165,7 @@ class GithubToEntries(TextToEntries):
if content_type in ["text", "code"]: if content_type in ["text", "code"]:
try: try:
content_str = content_bytes.decode("utf-8") content_str = content_bytes.decode("utf-8")
except: except Exception:
logger.error(f"Unable to decode content of file at {url_path}. Skip indexing it") logger.error(f"Unable to decode content of file at {url_path}. Skip indexing it")
continue continue
plaintext_files += [{"content": content_str, "path": url_path}] plaintext_files += [{"content": content_str, "path": url_path}]

View File

@@ -1,4 +1,3 @@
import base64
import logging import logging
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone

View File

@@ -1,6 +1,5 @@
import logging import logging
import re import re
from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import urllib3.util import urllib3.util
@@ -86,7 +85,7 @@ class MarkdownToEntries(TextToEntries):
# If content is small or content has no children headings, save it as a single entry # If content is small or content has no children headings, save it as a single entry
if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search( if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE rf"^#{{{len(ancestry) + 1},}}\s", markdown_content, flags=re.MULTILINE
): ):
# Create entry with line number information # Create entry with line number information
entry_with_line_info = (markdown_content_with_ancestry, markdown_file, start_line) entry_with_line_info = (markdown_content_with_ancestry, markdown_file, start_line)
@@ -160,7 +159,7 @@ class MarkdownToEntries(TextToEntries):
calculated_line = start_line if start_line > 0 else 1 calculated_line = start_line if start_line > 0 else 1
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path. # Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
if type(raw_filename) == str and re.search(r"^https?://", raw_filename): if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
# Escape the URL to avoid issues with special characters # Escape the URL to avoid issues with special characters
entry_filename = urllib3.util.parse_url(raw_filename).url entry_filename = urllib3.util.parse_url(raw_filename).url
uri = entry_filename uri = entry_filename

View File

@@ -91,7 +91,7 @@ class NotionToEntries(TextToEntries):
json=self.body_params, json=self.body_params,
).json() ).json()
responses.append(result) responses.append(result)
if result.get("has_more", False) == False: if not result.get("has_more", False):
break break
else: else:
self.body_params.update({"start_cursor": result["next_cursor"]}) self.body_params.update({"start_cursor": result["next_cursor"]})
@@ -118,7 +118,7 @@ class NotionToEntries(TextToEntries):
page_id = page["id"] page_id = page["id"]
title, content = self.get_page_content(page_id) title, content = self.get_page_content(page_id)
if title == None or content == None: if title is None or content is None:
return [] return []
current_entries = [] current_entries = []
@@ -126,11 +126,11 @@ class NotionToEntries(TextToEntries):
for block in content.get("results", []): for block in content.get("results", []):
block_type = block.get("type") block_type = block.get("type")
if block_type == None: if block_type is None:
continue continue
block_data = block[block_type] block_data = block[block_type]
if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0: if block_data.get("rich_text") is None or len(block_data["rich_text"]) == 0:
# There's no text to handle here. # There's no text to handle here.
continue continue
@@ -179,7 +179,7 @@ class NotionToEntries(TextToEntries):
results = children.get("results", []) results = children.get("results", [])
for child in results: for child in results:
child_type = child.get("type") child_type = child.get("type")
if child_type == None: if child_type is None:
continue continue
child_data = child[child_type] child_data = child[child_type]
if child_data.get("rich_text") and len(child_data["rich_text"]) > 0: if child_data.get("rich_text") and len(child_data["rich_text"]) > 0:

View File

@@ -8,7 +8,6 @@ from khoj.database.models import KhojUser
from khoj.processor.content.org_mode import orgnode from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.org_mode.orgnode import Orgnode from khoj.processor.content.org_mode.orgnode import Orgnode
from khoj.processor.content.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
@@ -103,7 +102,7 @@ class OrgToEntries(TextToEntries):
# If content is small or content has no children headings, save it as a single entry # If content is small or content has no children headings, save it as a single entry
# Note: This is the terminating condition for this recursive function # Note: This is the terminating condition for this recursive function
if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search( if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE rf"^\*{{{len(ancestry) + 1},}}\s", org_content, re.MULTILINE
): ):
orgnode_content_with_ancestry = orgnode.makelist( orgnode_content_with_ancestry = orgnode.makelist(
org_content_with_ancestry, org_file, start_line=start_line, ancestry_lines=len(ancestry) org_content_with_ancestry, org_file, start_line=start_line, ancestry_lines=len(ancestry)
@@ -195,7 +194,7 @@ class OrgToEntries(TextToEntries):
if not entry_heading and parsed_entry.level > 0: if not entry_heading and parsed_entry.level > 0:
base_level = parsed_entry.level base_level = parsed_entry.level
# Indent entry by 1 heading level as ancestry is prepended as top level heading # Indent entry by 1 heading level as ancestry is prepended as top level heading
heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else "" heading = f"{'*' * (parsed_entry.level - base_level + 2)} {todo_str}" if parsed_entry.level > 0 else ""
if parsed_entry.heading: if parsed_entry.heading:
heading += f"{parsed_entry.heading}." heading += f"{parsed_entry.heading}."
@@ -212,10 +211,10 @@ class OrgToEntries(TextToEntries):
compiled += f"\t {tags_str}." compiled += f"\t {tags_str}."
if parsed_entry.closed: if parsed_entry.closed:
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' compiled += f"\n Closed on {parsed_entry.closed.strftime('%Y-%m-%d')}."
if parsed_entry.scheduled: if parsed_entry.scheduled:
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' compiled += f"\n Scheduled for {parsed_entry.scheduled.strftime('%Y-%m-%d')}."
if parsed_entry.hasBody: if parsed_entry.hasBody:
compiled += f"\n {parsed_entry.body}" compiled += f"\n {parsed_entry.body}"

View File

@@ -65,7 +65,7 @@ def makelist(file, filename, start_line: int = 1, ancestry_lines: int = 0) -> Li
""" """
ctr = 0 ctr = 0
if type(file) == str: if isinstance(file, str):
f = file.splitlines() f = file.splitlines()
else: else:
f = file f = file
@@ -512,11 +512,11 @@ class Orgnode(object):
if self._closed or self._scheduled or self._deadline: if self._closed or self._scheduled or self._deadline:
n = n + indent n = n + indent
if self._closed: if self._closed:
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' n = n + f"CLOSED: [{self._closed.strftime('%Y-%m-%d %a')}] "
if self._scheduled: if self._scheduled:
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' n = n + f"SCHEDULED: <{self._scheduled.strftime('%Y-%m-%d %a')}> "
if self._deadline: if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' n = n + f"DEADLINE: <{self._deadline.strftime('%Y-%m-%d %a')}> "
if self._closed or self._scheduled or self._deadline: if self._closed or self._scheduled or self._deadline:
n = n + "\n" n = n + "\n"

View File

@@ -1,6 +1,5 @@
import logging import logging
import re import re
from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import urllib3 import urllib3
@@ -97,7 +96,7 @@ class PlaintextToEntries(TextToEntries):
for parsed_entry in parsed_entries: for parsed_entry in parsed_entries:
raw_filename = entry_to_file_map[parsed_entry] raw_filename = entry_to_file_map[parsed_entry]
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path. # Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
if type(raw_filename) == str and re.search(r"^https?://", raw_filename): if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
# Escape the URL to avoid issues with special characters # Escape the URL to avoid issues with special characters
entry_filename = urllib3.util.parse_url(raw_filename).url entry_filename = urllib3.util.parse_url(raw_filename).url
else: else:

View File

@@ -30,8 +30,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter() self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ...
...
@staticmethod @staticmethod
def hash_func(key: str) -> Callable: def hash_func(key: str) -> Callable:

View File

@@ -194,7 +194,7 @@ def gemini_completion_with_backoff(
or not response.candidates[0].content or not response.candidates[0].content
or response.candidates[0].content.parts is None or response.candidates[0].content.parts is None
): ):
raise ValueError(f"Failed to get response from model.") raise ValueError("Failed to get response from model.")
raw_content = [part.model_dump() for part in response.candidates[0].content.parts] raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
if response.function_calls: if response.function_calls:
function_calls = [ function_calls = [
@@ -212,7 +212,7 @@ def gemini_completion_with_backoff(
response = None response = None
# Handle 429 rate limit errors directly # Handle 429 rate limit errors directly
if e.code == 429: if e.code == 429:
response_text = f"My brain is exhausted. Can you please try again in a bit?" response_text = "My brain is exhausted. Can you please try again in a bit?"
# Log the full error details for debugging # Log the full error details for debugging
logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}") logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
# Handle other errors # Handle other errors
@@ -361,7 +361,7 @@ def handle_gemini_response(
# Ensure we have a proper list of candidates # Ensure we have a proper list of candidates
if not isinstance(candidates, list): if not isinstance(candidates, list):
message = f"\nUnexpected response format. Try again." message = "\nUnexpected response format. Try again."
stopped = True stopped = True
return message, stopped return message, stopped

View File

@@ -2,7 +2,6 @@ import json
import logging import logging
import os import os
from copy import deepcopy from copy import deepcopy
from functools import partial
from time import perf_counter from time import perf_counter
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -284,9 +283,9 @@ async def chat_completion_with_backoff(
if len(system_messages) > 0: if len(system_messages) > 0:
first_system_message_index, first_system_message = system_messages[0] first_system_message_index, first_system_message = system_messages[0]
first_system_message_content = first_system_message["content"] first_system_message_content = first_system_message["content"]
formatted_messages[first_system_message_index][ formatted_messages[first_system_message_index]["content"] = (
"content" f"{first_system_message_content}\nFormatting re-enabled"
] = f"{first_system_message_content}\nFormatting re-enabled" )
elif is_twitter_reasoning_model(model_name, api_base_url): elif is_twitter_reasoning_model(model_name, api_base_url):
reasoning_effort = "high" if deepthought else "low" reasoning_effort = "high" if deepthought else "low"
# Grok-4 models do not support reasoning_effort parameter # Grok-4 models do not support reasoning_effort parameter

View File

@@ -1,7 +1,6 @@
import base64 import base64
import json import json
import logging import logging
import math
import mimetypes import mimetypes
import os import os
import re import re
@@ -18,7 +17,7 @@ import requests
import tiktoken import tiktoken
import yaml import yaml
from langchain_core.messages.chat import ChatMessage from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError, create_model from pydantic import BaseModel, ConfigDict, ValidationError
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
@@ -47,7 +46,11 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from git import Repo import importlib.util
git_spec = importlib.util.find_spec("git")
if git_spec is None:
raise ImportError
except ImportError: except ImportError:
if is_promptrace_enabled(): if is_promptrace_enabled():
logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.") logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.")
@@ -294,7 +297,7 @@ def construct_chat_history_for_operator(conversation_history: List[ChatMessageMo
if chat.by == "you" and chat.message: if chat.by == "you" and chat.message:
content = [{"type": "text", "text": chat.message}] content = [{"type": "text", "text": chat.message}]
for file in chat.queryFiles or []: for file in chat.queryFiles or []:
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] content += [{"type": "text", "text": f"## File: {file['name']}\n\n{file['content']}"}]
user_message = AgentMessage(role="user", content=content) user_message = AgentMessage(role="user", content=content)
elif chat.by == "khoj" and chat.message: elif chat.by == "khoj" and chat.message:
chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)] chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
@@ -311,7 +314,10 @@ def construct_tool_chat_history(
If no tool is provided inferred query for all tools used are added. If no tool is provided inferred query for all tools used are added.
""" """
chat_history: list = [] chat_history: list = []
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
def base_extractor(iteration: ResearchIteration) -> List[str]:
return []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = { extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
ConversationCommand.SemanticSearchFiles: ( ConversationCommand.SemanticSearchFiles: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
@@ -498,7 +504,7 @@ async def save_to_conversation_log(
logger.info( logger.info(
f""" f"""
Saved Conversation Turn ({db_conversation.id if db_conversation else 'N/A'}): Saved Conversation Turn ({db_conversation.id if db_conversation else "N/A"}):
You ({user.username}): "{q}" You ({user.username}): "{q}"
Khoj: "{chat_response}" Khoj: "{chat_response}"
@@ -625,7 +631,7 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(chat.operatorContext): if not is_none_or_empty(chat.operatorContext):
operator_context = chat.operatorContext operator_context = chat.operatorContext
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) operator_content = "\n\n".join([f"## Task: {oc['query']}\n{oc['response']}\n" for oc in operator_context])
message_context += [ message_context += [
{ {
"type": "text", "type": "text",
@@ -744,7 +750,7 @@ def get_encoder(
else: else:
# as tiktoken doesn't recognize o1 model series yet # as tiktoken doesn't recognize o1 model series yet
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
except: except Exception:
encoder = tiktoken.encoding_for_model(default_tokenizer) encoder = tiktoken.encoding_for_model(default_tokenizer)
if state.verbose > 2: if state.verbose > 2:
logger.debug( logger.debug(
@@ -846,9 +852,9 @@ def truncate_messages(
total_tokens, _ = count_total_tokens(messages, encoder, system_message) total_tokens, _ = count_total_tokens(messages, encoder, system_message)
if total_tokens > max_prompt_size: if total_tokens > max_prompt_size:
# At this point, a single message with a single content part of type dict should remain # At this point, a single message with a single content part of type dict should remain
assert ( assert len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict), (
len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict) "Expected a single message with a single content part remaining at this point in truncation"
), "Expected a single message with a single content part remaining at this point in truncation" )
# Collate message content into single string to ease truncation # Collate message content into single string to ease truncation
part = messages[0].content[0] part = messages[0].content[0]

View File

@@ -1,8 +1,6 @@
import logging import logging
from typing import List from typing import List
from urllib.parse import urlparse
import openai
import requests import requests
import tqdm import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer from sentence_transformers import CrossEncoder, SentenceTransformer

View File

@@ -108,12 +108,12 @@ async def text_to_image(
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore message = "Image generation blocked by OpenAI due to policy violation" # type: ignore
yield image_url or image, status_code, message yield image_url or image, status_code, message
return return
else: else:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed using OpenAI" # type: ignore message = "Image generation failed using OpenAI" # type: ignore
status_code = e.status_code # type: ignore status_code = e.status_code # type: ignore
yield image_url or image, status_code, message yield image_url or image, status_code, message
return return
@@ -199,7 +199,7 @@ def generate_image_with_stability(
# Call Stability AI API to generate image # Call Stability AI API to generate image
response = requests.post( response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3", "https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""}, files={"none": ""},
data={ data={

View File

@@ -11,7 +11,7 @@ from khoj.processor.conversation.utils import (
OperatorRun, OperatorRun,
construct_chat_history_for_operator, construct_chat_history_for_operator,
) )
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import RequestUserAction
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
@@ -59,7 +59,7 @@ async def operate_environment(
if not reasoning_model or not reasoning_model.vision_enabled: if not reasoning_model or not reasoning_model.vision_enabled:
reasoning_model = await ConversationAdapters.aget_vision_enabled_config() reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
if not reasoning_model: if not reasoning_model:
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") raise ValueError("No vision enabled chat model found. Configure a vision chat model to operate environment.")
# Create conversation history from conversation log # Create conversation history from conversation log
chat_history = construct_chat_history_for_operator(conversation_log) chat_history = construct_chat_history_for_operator(conversation_log)

View File

@@ -1,14 +1,27 @@
import json import json
import logging import logging
from textwrap import dedent from textwrap import dedent
from typing import List, Optional
from openai import AzureOpenAI, OpenAI from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat import ChatCompletion, ChatCompletionMessage
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import (
from khoj.processor.operator.operator_agent_base import AgentActResult BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeypressAction,
OperatorAction,
Point,
ScreenshotAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics from khoj.utils.helpers import get_chat_usage_metrics

View File

@@ -18,7 +18,22 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from PIL import Image from PIL import Image
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeyDownAction,
KeypressAction,
KeyUpAction,
MoveAction,
OperatorAction,
RequestUserAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics from khoj.utils.helpers import get_chat_usage_metrics
@@ -122,11 +137,10 @@ class GroundingAgentUitars:
) )
temperature = self.temperature temperature = self.temperature
top_k = self.top_k
try_times = 3 try_times = 3
while not parsed_responses: while not parsed_responses:
if try_times <= 0: if try_times <= 0:
logger.warning(f"Reach max retry times to fetch response from client, as error flag.") logger.warning("Reach max retry times to fetch response from client, as error flag.")
return "client error\nFAIL", [] return "client error\nFAIL", []
try: try:
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages]) message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
@@ -163,7 +177,6 @@ class GroundingAgentUitars:
prediction = None prediction = None
try_times -= 1 try_times -= 1
temperature = 1 temperature = 1
top_k = -1
if prediction is None: if prediction is None:
return "client error\nFAIL", [] return "client error\nFAIL", []
@@ -264,9 +277,9 @@ class GroundingAgentUitars:
raise ValueError(f"Unsupported environment type: {environment_type}") raise ValueError(f"Unsupported environment type: {environment_type}")
def _format_messages_for_api(self, instruction: str, current_state: EnvState): def _format_messages_for_api(self, instruction: str, current_state: EnvState):
assert len(self.observations) == len(self.actions) and len(self.actions) == len( assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), (
self.thoughts "The number of observations and actions should be the same."
), "The number of observations and actions should be the same." )
self.history_images.append(base64.b64decode(current_state.screenshot)) self.history_images.append(base64.b64decode(current_state.screenshot))
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None}) self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
@@ -524,7 +537,7 @@ class GroundingAgentUitars:
parsed_actions = [self.parse_action_string(action.replace("\n", "\\n").lstrip()) for action in all_action] parsed_actions = [self.parse_action_string(action.replace("\n", "\\n").lstrip()) for action in all_action]
actions: list[dict] = [] actions: list[dict] = []
for action_instance, raw_str in zip(parsed_actions, all_action): for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance == None: if action_instance is None:
print(f"Action can't parse: {raw_str}") print(f"Action can't parse: {raw_str}")
raise ValueError(f"Action can't parse: {raw_str}") raise ValueError(f"Action can't parse: {raw_str}")
action_type = action_instance["function"] action_type = action_instance["function"]
@@ -756,7 +769,7 @@ class GroundingAgentUitars:
The pyautogui code string The pyautogui code string
""" """
pyautogui_code = f"import pyautogui\nimport time\n" pyautogui_code = "import pyautogui\nimport time\n"
actions = [] actions = []
if isinstance(responses, dict): if isinstance(responses, dict):
responses = [responses] responses = [responses]
@@ -774,7 +787,7 @@ class GroundingAgentUitars:
if response_id == 0: if response_id == 0:
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n" pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
else: else:
pyautogui_code += f"\ntime.sleep(1)\n" pyautogui_code += "\ntime.sleep(1)\n"
action_dict = response action_dict = response
action_type = action_dict.get("action_type") action_type = action_dict.get("action_type")
@@ -846,17 +859,17 @@ class GroundingAgentUitars:
if content: if content:
if input_swap: if input_swap:
actions += TypeAction() actions += TypeAction()
pyautogui_code += f"\nimport pyperclip" pyautogui_code += "\nimport pyperclip"
pyautogui_code += f"\npyperclip.copy('{stripped_content}')" pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += f"\ntime.sleep(0.5)\n" pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"): if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')" pyautogui_code += "\npyautogui.press('enter')"
else: else:
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
pyautogui_code += f"\ntime.sleep(0.5)\n" pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"): if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')" pyautogui_code += "\npyautogui.press('enter')"
elif action_type in ["drag", "select"]: elif action_type in ["drag", "select"]:
# Parsing drag or select action based on start and end_boxes # Parsing drag or select action based on start and end_boxes
@@ -869,9 +882,7 @@ class GroundingAgentUitars:
x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2] x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2]
ex = round(float((x1 + x2) / 2) * image_width, 3) ex = round(float((x1 + x2) / 2) * image_width, 3)
ey = round(float((y1 + y2) / 2) * image_height, 3) ey = round(float((y1 + y2) / 2) * image_height, 3)
pyautogui_code += ( pyautogui_code += f"\npyautogui.moveTo({sx}, {sy})\n\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
f"\npyautogui.moveTo({sx}, {sy})\n" f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
)
elif action_type == "scroll": elif action_type == "scroll":
# Parsing scroll action # Parsing scroll action
@@ -888,11 +899,11 @@ class GroundingAgentUitars:
y = None y = None
direction = action_inputs.get("direction", "") direction = action_inputs.get("direction", "")
if x == None: if x is None:
if "up" in direction.lower(): if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5)" pyautogui_code += "\npyautogui.scroll(5)"
elif "down" in direction.lower(): elif "down" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(-5)" pyautogui_code += "\npyautogui.scroll(-5)"
else: else:
if "up" in direction.lower(): if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
@@ -923,7 +934,7 @@ class GroundingAgentUitars:
pyautogui_code += f"\npyautogui.moveTo({x}, {y})" pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
elif action_type in ["finished"]: elif action_type in ["finished"]:
pyautogui_code = f"DONE" pyautogui_code = "DONE"
else: else:
pyautogui_code += f"\n# Unrecognized action type: {action_type}" pyautogui_code += f"\n# Unrecognized action type: {action_type}"

View File

@@ -11,7 +11,32 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.conversation.utils import AgentMessage from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
CursorPositionAction,
DoubleClickAction,
DragAction,
GotoAction,
HoldKeyAction,
KeypressAction,
MouseDownAction,
MouseUpAction,
MoveAction,
NoopAction,
OperatorAction,
Point,
ScreenshotAction,
ScrollAction,
TerminalAction,
TextEditorCreateAction,
TextEditorInsertAction,
TextEditorStrReplaceAction,
TextEditorViewAction,
TripleClickAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import ( from khoj.processor.operator.operator_environment_base import (
EnvironmentType, EnvironmentType,
@@ -518,7 +543,7 @@ class AnthropicOperatorAgent(OperatorAgent):
def model_default_headers(self) -> list[str]: def model_default_headers(self) -> list[str]:
"""Get the default computer use headers for the given model.""" """Get the default computer use headers for the given model."""
if self.vision_model.name.startswith("claude-3-7-sonnet"): if self.vision_model.name.startswith("claude-3-7-sonnet"):
return [f"computer-use-2025-01-24", "token-efficient-tools-2025-02-19"] return ["computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"): elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
return ["computer-use-2025-01-24"] return ["computer-use-2025-01-24"]
else: else:
@@ -538,7 +563,7 @@ class AnthropicOperatorAgent(OperatorAgent):
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail. * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {current_state.url}. * The current URL is {current_state.url}.
</SYSTEM_CAPABILITY> </SYSTEM_CAPABILITY>
@@ -563,7 +588,7 @@ class AnthropicOperatorAgent(OperatorAgent):
</SYSTEM_CAPABILITY> </SYSTEM_CAPABILITY>
<CONTEXT> <CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
</CONTEXT> </CONTEXT>
""" """
).lstrip() ).lstrip()

View File

@@ -1,6 +1,6 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -12,7 +12,7 @@ from khoj.processor.conversation.utils import (
) )
from khoj.processor.operator.grounding_agent import GroundingAgent from khoj.processor.operator.grounding_agent import GroundingAgent
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import OperatorAction, WaitAction
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import ( from khoj.processor.operator.operator_environment_base import (
EnvironmentType, EnvironmentType,
@@ -181,7 +181,7 @@ class BinaryOperatorAgent(OperatorAgent):
elif action.type == "key_down": elif action.type == "key_down":
rendered_parts += [f'**Action**: Press Key "{action.key}"'] rendered_parts += [f'**Action**: Press Key "{action.key}"']
elif action.type == "screenshot" and not current_state.screenshot: elif action.type == "screenshot" and not current_state.screenshot:
rendered_parts += [f"**Error**: Failed to take screenshot"] rendered_parts += ["**Error**: Failed to take screenshot"]
elif action.type == "goto": elif action.type == "goto":
rendered_parts += [f"**Action**: Open URL {action.url}"] rendered_parts += [f"**Action**: Open URL {action.url}"]
else: else:
@@ -317,7 +317,7 @@ class BinaryOperatorAgent(OperatorAgent):
# Introduction # Introduction
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser. * You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
* You are given the user's query and screenshots of the browser's state transitions. * You are given the user's query and screenshots of the browser's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {env_state.url}. * The current URL is {env_state.url}.
# Your Task # Your Task
@@ -362,7 +362,7 @@ class BinaryOperatorAgent(OperatorAgent):
# Introduction # Introduction
* You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer. * You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
* You are given the user's query and screenshots of the computer's state transitions. * You are given the user's query and screenshots of the computer's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
# Your Task # Your Task
* First look at the screenshots carefully to notice all pertinent information. * First look at the screenshots carefully to notice all pertinent information.

View File

@@ -1,6 +1,5 @@
import json import json
import logging import logging
import platform
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from textwrap import dedent from textwrap import dedent
@@ -10,7 +9,23 @@ from openai.types.responses import Response, ResponseOutputItem
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import AgentMessage from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeypressAction,
MoveAction,
NoopAction,
OperatorAction,
Point,
RequestUserAction,
ScreenshotAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import ( from khoj.processor.operator.operator_environment_base import (
EnvironmentType, EnvironmentType,
@@ -152,7 +167,7 @@ class OpenAIOperatorAgent(OperatorAgent):
# Add screenshot data in openai message format # Add screenshot data in openai message format
action_result["output"] = { action_result["output"] = {
"type": "input_image", "type": "input_image",
"image_url": f'data:image/webp;base64,{result_content["image"]}', "image_url": f"data:image/webp;base64,{result_content['image']}",
"current_url": result_content["url"], "current_url": result_content["url"],
} }
elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1: elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1:
@@ -311,7 +326,7 @@ class OpenAIOperatorAgent(OperatorAgent):
elif block.type == "function_call": elif block.type == "function_call":
if block.name == "goto": if block.name == "goto":
args = json.loads(block.arguments) args = json.loads(block.arguments)
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}'] render_texts = [f"Open URL: {args.get('url', '[Missing URL]')}"]
else: else:
render_texts += [block.name] render_texts += [block.name]
elif block.type == "computer_call": elif block.type == "computer_call":
@@ -351,7 +366,7 @@ class OpenAIOperatorAgent(OperatorAgent):
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail. * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {current_state.url}. * The current URL is {current_state.url}.
</SYSTEM_CAPABILITY> </SYSTEM_CAPABILITY>
@@ -374,7 +389,7 @@ class OpenAIOperatorAgent(OperatorAgent):
</SYSTEM_CAPABILITY> </SYSTEM_CAPABILITY>
<CONTEXT> <CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
</CONTEXT> </CONTEXT>
""" """
).lstrip() ).lstrip()

View File

@@ -247,7 +247,7 @@ class BrowserEnvironment(Environment):
case "drag": case "drag":
if not isinstance(action, DragAction): if not isinstance(action, DragAction):
raise TypeError(f"Invalid action type for drag") raise TypeError("Invalid action type for drag")
path = action.path path = action.path
if not path: if not path:
error = "Missing path for drag action" error = "Missing path for drag action"

View File

@@ -532,7 +532,7 @@ class ComputerEnvironment(Environment):
else: else:
return {"success": False, "output": process.stdout, "error": process.stderr} return {"success": False, "output": process.stdout, "error": process.stderr}
except asyncio.TimeoutError: except asyncio.TimeoutError:
return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."} return {"success": False, "output": "", "error": "Command timed out after 120 seconds."}
except Exception as e: except Exception as e:
return {"success": False, "output": "", "error": str(e)} return {"success": False, "output": "", "error": str(e)}

View File

@@ -1,4 +1,3 @@
import json # Used for working with JSON data
import os import os
import requests # Used for making HTTP requests import requests # Used for making HTTP requests

View File

@@ -385,7 +385,7 @@ async def read_webpages(
tracer: dict = {}, tracer: dict = {},
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info("Inferring web pages to read")
urls = await infer_webpage_urls( urls = await infer_webpage_urls(
query, query,
max_webpages_to_read, max_webpages_to_read,

View File

@@ -93,7 +93,7 @@ async def run_code(
# Run Code # Run Code
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Running code snippet**"): async for event in send_status_func("**Running code snippet**"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
try: try:
with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO): with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):

View File

@@ -7,7 +7,6 @@ from typing import List, Optional, Union
import openai import openai
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
@@ -94,7 +93,7 @@ def update(
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
raise HTTPException(status_code=500, detail=error_msg) raise HTTPException(status_code=500, detail=error_msg)
else: else:
logger.info(f"📪 Server indexed content updated via API") logger.info("📪 Server indexed content updated via API")
update_telemetry_state( update_telemetry_state(
request=request, request=request,

View File

@@ -6,12 +6,11 @@ from typing import Dict, List, Optional
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from pydantic import BaseModel from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import ( from khoj.utils.helpers import (

View File

@@ -109,7 +109,7 @@ def post_automation(
except Exception as e: except Exception as e:
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True) logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
return Response( return Response(
content=f"Unable to create automation. Ensure the automation doesn't already exist.", content="Unable to create automation. Ensure the automation doesn't already exist.",
media_type="text/plain", media_type="text/plain",
status_code=500, status_code=500,
) )

View File

@@ -10,7 +10,6 @@ from functools import partial
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Depends, Depends,
@@ -32,10 +31,10 @@ from khoj.database.adapters import (
PublicConversationAdapters, PublicConversationAdapters,
aget_user_name, aget_user_name,
) )
from khoj.database.models import Agent, ChatMessageModel, KhojUser from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import is_local_api from khoj.processor.conversation.openai.utils import is_local_api
from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.prompts import no_entries_found
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
OperatorRun, OperatorRun,
ResponseWithThought, ResponseWithThought,
@@ -65,11 +64,8 @@ from khoj.routers.helpers import (
acreate_title_from_history, acreate_title_from_history,
agenerate_chat_response, agenerate_chat_response,
aget_data_sources_and_output_format, aget_data_sources_and_output_format,
construct_automation_created_message,
create_automation,
gather_raw_query_files, gather_raw_query_files,
generate_mermaidjs_diagram, generate_mermaidjs_diagram,
generate_summary_from_files,
get_conversation_command, get_conversation_command,
get_message_from_queue, get_message_from_queue,
is_query_empty, is_query_empty,
@@ -89,13 +85,11 @@ from khoj.utils.helpers import (
convert_image_to_webp, convert_image_to_webp,
get_country_code_from_timezone, get_country_code_from_timezone,
get_country_name_from_timezone, get_country_name_from_timezone,
get_device,
is_env_var_true, is_env_var_true,
is_none_or_empty, is_none_or_empty,
is_operator_enabled, is_operator_enabled,
) )
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ChatRequestBody,
FileAttachment, FileAttachment,
FileFilterRequest, FileFilterRequest,
FilesFilterRequest, FilesFilterRequest,
@@ -689,7 +683,6 @@ async def event_generator(
region = body.region region = body.region
country = body.country or get_country_name_from_timezone(body.timezone) country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone) country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
raw_images = body.images raw_images = body.images
raw_query_files = body.files raw_query_files = body.files
@@ -853,7 +846,8 @@ async def event_generator(
if ( if (
len(train_of_thought) > 0 len(train_of_thought) > 0
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
and type(train_of_thought[-1]["data"]) == type(data) == str and isinstance(train_of_thought[-1]["data"], str)
and isinstance(data, str)
): ):
train_of_thought[-1]["data"] += data train_of_thought[-1]["data"] += data
else: else:
@@ -1075,11 +1069,11 @@ async def event_generator(
# researched_results = await extract_relevant_info(q, researched_results, agent) # researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1: if state.verbose > 1:
logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}")
# Gather Context # Gather Context
## Extract Document References ## Extract Document References
if not ConversationCommand.Research in conversation_commands: if ConversationCommand.Research not in conversation_commands:
try: try:
async for result in search_documents( async for result in search_documents(
q, q,
@@ -1218,7 +1212,7 @@ async def event_generator(
else: else:
code_results = result code_results = result
except ValueError as e: except ValueError as e:
program_execution_context.append(f"Failed to run code") program_execution_context.append("Failed to run code")
logger.warning( logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results", f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True, exc_info=True,
@@ -1297,7 +1291,7 @@ async def event_generator(
inferred_queries.append(improved_image_prompt) inferred_queries.append(improved_image_prompt)
if generated_image is None or status_code != 200: if generated_image is None or status_code != 200:
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): async for result in send_event(ChatEvent.STATUS, "Failed to generate image"):
yield result yield result
else: else:
generated_images.append(generated_image) generated_images.append(generated_image)
@@ -1315,7 +1309,7 @@ async def event_generator(
yield result yield result
if ConversationCommand.Diagram in conversation_commands: if ConversationCommand.Diagram in conversation_commands:
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): async for result in send_event(ChatEvent.STATUS, "Creating diagram"):
yield result yield result
inferred_queries = [] inferred_queries = []
@@ -1372,7 +1366,7 @@ async def event_generator(
return return
## Generate Text Output ## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): async for result in send_event(ChatEvent.STATUS, "**Generating a well-informed response**"):
yield result yield result
llm_response, chat_metadata = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(

View File

@@ -3,7 +3,6 @@ import logging
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires from starlette.authentication import has_required_scope, requires

View File

@@ -117,7 +117,7 @@ async def subscribe(request: Request):
) )
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}") logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
logger.info(f'Stripe subscription {event["type"]} for {customer_email}') logger.info(f"Stripe subscription {event['type']} for {customer_email}")
return {"success": success} return {"success": success}

View File

@@ -44,7 +44,7 @@ async def send_magic_link_email(email, unique_id, host):
{ {
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
"to": email, "to": email,
"subject": f"Your login code to Khoj", "subject": "Your login code to Khoj",
"html": html_content, "html": html_content,
} }
) )
@@ -98,11 +98,11 @@ async def send_query_feedback(uquery, kquery, sentiment, user_email):
user_email=user_email if not is_none_or_empty(user_email) else "N/A", user_email=user_email if not is_none_or_empty(user_email) else "N/A",
) )
# send feedback to fixed account # send feedback to fixed account
r = resend.Emails.send( resend.Emails.send(
{ {
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
"to": "team@khoj.dev", "to": "team@khoj.dev",
"subject": f"User Feedback", "subject": "User Feedback",
"html": html_content, "html": html_content,
} }
) )
@@ -127,7 +127,7 @@ def send_task_email(name, email, query, result, subject, is_image=False):
r = resend.Emails.send( r = resend.Emails.send(
{ {
"sender": f'Khoj <{os.environ.get("RESEND_EMAIL", "khoj@khoj.dev")}>', "sender": f"Khoj <{os.environ.get('RESEND_EMAIL', 'khoj@khoj.dev')}>",
"to": email, "to": email,
"subject": f"{subject}", "subject": f"{subject}",
"html": html_content, "html": html_content,

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import base64 import base64
import concurrent.futures
import fnmatch import fnmatch
import hashlib import hashlib
import json import json
@@ -47,14 +46,12 @@ from khoj.database.adapters import (
EntryAdapters, EntryAdapters,
FileObjectAdapters, FileObjectAdapters,
aget_user_by_email, aget_user_by_email,
ais_user_subscribed,
create_khoj_token, create_khoj_token,
get_default_search_model, get_default_search_model,
get_khoj_tokens, get_khoj_tokens,
get_user_name, get_user_name,
get_user_notion_config, get_user_notion_config,
get_user_subscription_state, get_user_subscription_state,
is_user_subscribed,
run_with_process_lock, run_with_process_lock,
) )
from khoj.database.models import ( from khoj.database.models import (
@@ -160,7 +157,7 @@ def validate_chat_model(user: KhojUser):
async def is_ready_to_chat(user: KhojUser): async def is_ready_to_chat(user: KhojUser):
user_chat_model = await ConversationAdapters.aget_user_chat_model(user) user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
if user_chat_model == None: if user_chat_model is None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user) user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
if ( if (
@@ -581,7 +578,7 @@ async def generate_online_subqueries(
) )
return {q} return {q}
return response return response
except Exception as e: except Exception:
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}") logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
return {q} return {q}
@@ -1172,8 +1169,8 @@ async def search_documents(
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
if ( if (
not ConversationCommand.Notes in conversation_commands ConversationCommand.Notes not in conversation_commands
and not ConversationCommand.Default in conversation_commands and ConversationCommand.Default not in conversation_commands
and not agent_has_entries and not agent_has_entries
): ):
yield compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
@@ -1325,8 +1322,8 @@ async def extract_questions(
logger.error(f"Invalid response for constructing subqueries: {response}") logger.error(f"Invalid response for constructing subqueries: {response}")
return [query] return [query]
return queries return queries
except: except Exception:
logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.") logger.warning("LLM returned invalid JSON. Falling back to using user message as search query.")
return [query] return [query]
@@ -1351,7 +1348,7 @@ async def execute_search(
return results return results
if q is None or q == "": if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search") logger.warning("No query param (q) passed in API call to initiate search")
return results return results
# initialize variables # initialize variables
@@ -1364,7 +1361,7 @@ async def execute_search(
if user: if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]: if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache") logger.debug("Return response from query cache")
return state.query_cache[user.uuid][query_cache_key] return state.query_cache[user.uuid][query_cache_key]
# Encode query with filter terms removed # Encode query with filter terms removed
@@ -1875,8 +1872,8 @@ class ApiUserRateLimiter:
user: KhojUser = websocket.scope["user"].object user: KhojUser = websocket.scope["user"].object
subscribed = has_required_scope(websocket, ["premium"]) subscribed = has_required_scope(websocket, ["premium"])
current_window = "today" if self.window == 60 * 60 * 24 else f"now" current_window = "today" if self.window == 60 * 60 * 24 else "now"
next_window = "tomorrow" if self.window == 60 * 60 * 24 else f"in a bit" next_window = "tomorrow" if self.window == 60 * 60 * 24 else "in a bit"
common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}." common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}."
# Remove requests outside of the time window # Remove requests outside of the time window
@@ -2219,7 +2216,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str, us
should_notify_result = response["decision"] == "Yes" should_notify_result = response["decision"] == "Yes"
reason = response.get("reason", "unknown") reason = response.get("reason", "unknown")
logger.info( logger.info(
f'Decided to {"not " if not should_notify_result else ""}notify user of automation response because of reason: {reason}.' f"Decided to {'not ' if not should_notify_result else ''}notify user of automation response because of reason: {reason}."
) )
return should_notify_result return should_notify_result
except Exception as e: except Exception as e:
@@ -2313,7 +2310,7 @@ def scheduled_chat(
response_map = raw_response.json() response_map = raw_response.json()
ai_response = response_map.get("response") or response_map.get("image") ai_response = response_map.get("response") or response_map.get("image")
is_image = False is_image = False
if type(ai_response) == dict: if isinstance(ai_response, dict):
is_image = ai_response.get("image") is not None is_image = ai_response.get("image") is not None
else: else:
ai_response = raw_response.text ai_response = raw_response.text
@@ -2460,12 +2457,12 @@ async def aschedule_automation(
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str): def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
# Display next run time in user timezone instead of UTC # Display next run time in user timezone instead of UTC
schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}' schedule = f"{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime('%Z')}"
next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z") next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
# Remove /automated_task prefix from inferred_query # Remove /automated_task prefix from inferred_query
unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run) unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)
# Create the automation response # Create the automation response
automation_icon_url = f"/static/assets/icons/automation.svg" automation_icon_url = "/static/assets/icons/automation.svg"
return f""" return f"""
### ![]({automation_icon_url}) Created Automation ### ![]({automation_icon_url}) Created Automation
- Subject: **{subject}** - Subject: **{subject}**
@@ -2713,13 +2710,13 @@ def configure_content(
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
) -> bool: ) -> bool:
success = True success = True
if t == None: if t is None:
t = state.SearchType.All t = state.SearchType.All
if t is not None and t in [type.value for type in state.SearchType]: if t is not None and t in [type.value for type in state.SearchType]:
t = state.SearchType(t) t = state.SearchType(t)
if t is not None and not t.value in [type.value for type in state.SearchType]: if t is not None and t.value not in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}") logger.warning(f"🚨 Invalid search type: {t}")
return False return False
@@ -2988,7 +2985,7 @@ async def grep_files(
query += f" {' and '.join(context_info)}" query += f" {' and '.join(context_info)}"
if line_count > max_results: if line_count > max_results:
if lines_before or lines_after: if lines_before or lines_after:
query += f" for" query += " for"
query += f" first {max_results} results" query += f" first {max_results} results"
return query return query

View File

@@ -15,7 +15,6 @@ from khoj.processor.conversation.utils import (
ResearchIteration, ResearchIteration,
ToolCall, ToolCall,
construct_iteration_history, construct_iteration_history,
construct_structured_message,
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json, load_complex_json,
) )
@@ -24,7 +23,6 @@ from khoj.processor.tools.online_search import read_webpages_content, search_onl
from khoj.processor.tools.run_code import run_code from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
generate_summary_from_files,
get_message_from_queue, get_message_from_queue,
grep_files, grep_files,
list_files, list_files,
@@ -184,7 +182,7 @@ async def apick_next_tool(
# TODO: Handle multiple tool calls. # TODO: Handle multiple tool calls.
response_text = response.text response_text = response.text
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
except Exception as e: except Exception:
# Otherwise assume the model has decided to end the research run and respond to the user. # Otherwise assume the model has decided to end the research run and respond to the user.
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
@@ -199,7 +197,7 @@ async def apick_next_tool(
if i.warning is None and isinstance(i.query, ToolCall) if i.warning is None and isinstance(i.query, ToolCall)
} }
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." warning = "Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration and model has thoughts to share. # Only send client status updates if we'll execute this iteration and model has thoughts to share.
elif send_status_func and not is_none_or_empty(response.thought): elif send_status_func and not is_none_or_empty(response.thought):
async for event in send_status_func(response.thought): async for event in send_status_func(response.thought):

View File

@@ -4,12 +4,10 @@ from typing import List
class BaseFilter(ABC): class BaseFilter(ABC):
@abstractmethod @abstractmethod
def get_filter_terms(self, query: str) -> List[str]: def get_filter_terms(self, query: str) -> List[str]: ...
...
def can_filter(self, raw_query: str) -> bool: def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0 return len(self.get_filter_terms(raw_query)) > 0
@abstractmethod @abstractmethod
def defilter(self, query: str) -> str: def defilter(self, query: str) -> str: ...
...

View File

@@ -9,9 +9,8 @@ from asgiref.sync import sync_to_async
from sentence_transformers import util from sentence_transformers import util
from khoj.database.adapters import EntryAdapters, get_default_search_model from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.database.models import Entry as DbEntry from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, timer from khoj.utils.helpers import get_absolute_path, timer

View File

@@ -77,7 +77,7 @@ class AsyncIteratorWrapper:
def is_none_or_empty(item): def is_none_or_empty(item):
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == "" return item is None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
def to_snake_case_from_dash(item: str): def to_snake_case_from_dash(item: str):
@@ -97,7 +97,7 @@ def get_from_dict(dictionary, *args):
Returns: dictionary[args[0]][args[1]]... or None if any keys missing""" Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
current = dictionary current = dictionary
for arg in args: for arg in args:
if not hasattr(current, "__iter__") or not arg in current: if not hasattr(current, "__iter__") or arg not in current:
return None return None
current = current[arg] current = current[arg]
return current return current
@@ -751,7 +751,7 @@ def is_valid_url(url: str) -> bool:
try: try:
result = urlparse(url.strip()) result = urlparse(url.strip())
return all([result.scheme, result.netloc]) return all([result.scheme, result.netloc])
except: except Exception:
return False return False
@@ -759,7 +759,7 @@ def is_internet_connected():
try: try:
response = requests.head("https://www.google.com") response = requests.head("https://www.google.com")
return response.status_code == 200 return response.status_code == 200
except: except Exception:
return False return False

View File

@@ -60,9 +60,7 @@ def initialization(interactive: bool = True):
] ]
default_chat_models = known_available_models + other_available_models default_chat_models = known_available_models + other_available_models
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}")
f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}"
)
# Set up OpenAI's online chat models # Set up OpenAI's online chat models
openai_configured, openai_provider = _setup_chat_model_provider( openai_configured, openai_provider = _setup_chat_model_provider(

View File

@@ -8,12 +8,10 @@ from tqdm import trange
class BaseEncoder(ABC): class BaseEncoder(ABC):
@abstractmethod @abstractmethod
def __init__(self, model_name: str, device: torch.device = None, **kwargs): def __init__(self, model_name: str, device: torch.device = None, **kwargs): ...
...
@abstractmethod @abstractmethod
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: ...
...
class OpenAI(BaseEncoder): class OpenAI(BaseEncoder):

View File

@@ -1,8 +1,7 @@
# System Packages # System Packages
import json import json
import uuid import uuid
from pathlib import Path from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -2,7 +2,7 @@ import os
import threading import threading
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Dict, List
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI from openai import OpenAI

View File

@@ -30,7 +30,7 @@ def v1_telemetry(telemetry_data: List[Dict[str, str]]):
try: try:
for row in telemetry_data: for row in telemetry_data:
posthog.capture(row["server_id"], "api_request", row) posthog.capture(row["server_id"], "api_request", row)
except Exception as e: except Exception:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Could not POST equest to new khoj telemetry server. Contact developer to get this fixed.", detail="Could not POST equest to new khoj telemetry server. Contact developer to get this fixed.",

View File

@@ -326,7 +326,7 @@ File statistics:
- Code examples: Yes - Code examples: Yes
- Purpose: Stress testing atomic agent updates - Purpose: Stress testing atomic agent updates
{'Additional padding content. ' * 20} {"Additional padding content. " * 20}
End of file {i}. End of file {i}.
""" """

View File

@@ -462,7 +462,7 @@ def evaluate_response_with_gemini(
Ground Truth: {ground_truth} Ground Truth: {ground_truth}
Provide your evaluation in the following json format: Provide your evaluation in the following json format:
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"} {"explanation:[How you made the decision?)", "decision:(TRUE if response contains key information, FALSE otherwise)"}
""" """
gemini_api_url = ( gemini_api_url = (
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}" f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
@@ -557,7 +557,7 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato
--------- ---------
Decision: {colored_decision} Decision: {colored_decision}
Accuracy: {running_accuracy:.2%} Accuracy: {running_accuracy:.2%}
Progress: {running_total_count.get()/dataset_length:.2%} Progress: {running_total_count.get() / dataset_length:.2%}
Index: {current_index} Index: {current_index}
Question: {prompt} Question: {prompt}
Expected Answer: {answer} Expected Answer: {answer}

View File

@@ -20,7 +20,7 @@ def test_create_default_agent(default_user: KhojUser):
assert agent.input_tools == [] assert agent.input_tools == []
assert agent.output_modes == [] assert agent.output_modes == []
assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC
assert agent.managed_by_admin == True assert agent.managed_by_admin
@pytest.mark.anyio @pytest.mark.anyio
@@ -178,7 +178,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
): ):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent( await AgentAdapters.aupdate_agent(
default_user2, default_user2,
"Test Agent", "Test Agent",
"Test Personality", "Test Personality",
@@ -290,17 +290,17 @@ async def test_large_knowledge_base_atomic_update(
assert len(final_entries) > initial_entries_count, "Should have more entries after update" assert len(final_entries) > initial_entries_count, "Should have more entries after update"
# With 180 files, we should have many entries (each file creates multiple entries) # With 180 files, we should have many entries (each file creates multiple entries)
assert ( assert len(final_entries) >= expected_file_count, (
len(final_entries) >= expected_file_count f"Expected at least {expected_file_count} entries, got {len(final_entries)}"
), f"Expected at least {expected_file_count} entries, got {len(final_entries)}" )
# Verify no partial state - all entries should correspond to the final file set # Verify no partial state - all entries should correspond to the final file set
entry_file_paths = {entry.file_path for entry in final_entries} entry_file_paths = {entry.file_path for entry in final_entries}
# All file objects should have corresponding entries # All file objects should have corresponding entries
assert file_paths_in_db.issubset( assert file_paths_in_db.issubset(entry_file_paths), (
entry_file_paths "All file objects should have corresponding entries - atomic update verification"
), "All file objects should have corresponding entries - atomic update verification" )
# Additional stress test: verify referential integrity # Additional stress test: verify referential integrity
# Count entries per file to ensure no partial file processing # Count entries per file to ensure no partial file processing
@@ -333,7 +333,7 @@ async def test_concurrent_agent_updates_atomicity(
test_files = available_files # Use all available files for the stress test test_files = available_files # Use all available files for the stress test
# Create initial agent # Create initial agent
agent = await AgentAdapters.aupdate_agent( await AgentAdapters.aupdate_agent(
default_user2, default_user2,
"Concurrent Test Agent", "Concurrent Test Agent",
"Test concurrent updates", "Test concurrent updates",
@@ -391,14 +391,14 @@ async def test_concurrent_agent_updates_atomicity(
file_object_paths = {fo.file_name for fo in final_file_objects} file_object_paths = {fo.file_name for fo in final_file_objects}
# All entries should have corresponding file objects # All entries should have corresponding file objects
assert entry_file_paths.issubset( assert entry_file_paths.issubset(file_object_paths), (
file_object_paths "All entries should have corresponding file objects - indicates atomic update worked"
), "All entries should have corresponding file objects - indicates atomic update worked" )
except Exception as e: except Exception as e:
# If we get database integrity errors, that's actually expected behavior # If we get database integrity errors, that's actually expected behavior
# with proper atomic transactions - they should fail cleanly rather than # with proper atomic transactions - they should fail cleanly rather than
# allowing partial updates # allowing partial updates
assert ( assert "database" in str(e).lower() or "integrity" in str(e).lower(), (
"database" in str(e).lower() or "integrity" in str(e).lower() f"Expected database/integrity error with concurrent updates, got: {e}"
), f"Expected database/integrity error with concurrent updates, got: {e}" )

View File

@@ -5,7 +5,6 @@ from urllib.parse import quote
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from PIL import Image
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
from khoj.database.adapters import EntryAdapters from khoj.database.adapters import EntryAdapters
@@ -101,7 +100,7 @@ def test_update_with_invalid_content_type(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.get(f"/api/update?t=invalid_content_type", headers=headers) response = client.get("/api/update?t=invalid_content_type", headers=headers)
# Assert # Assert
assert response.status_code == 422 assert response.status_code == 422
@@ -114,7 +113,7 @@ def test_regenerate_with_invalid_content_type(client):
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
# Act # Act
response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers) response = client.get("/api/update?force=true&t=invalid_content_type", headers=headers)
# Assert # Assert
assert response.status_code == 422 assert response.status_code == 422
@@ -238,13 +237,13 @@ def test_regenerate_with_valid_content_type(client):
def test_regenerate_with_github_fails_without_pat(client): def test_regenerate_with_github_fails_without_pat(client):
# Act # Act
headers = {"Authorization": "Bearer kk-secret"} headers = {"Authorization": "Bearer kk-secret"}
response = client.get(f"/api/update?force=true&t=github", headers=headers) response = client.get("/api/update?force=true&t=github", headers=headers)
# Arrange # Arrange
files = get_sample_files_data() files = get_sample_files_data()
# Act # Act
response = client.patch(f"/api/content?t=github", files=files, headers=headers) response = client.patch("/api/content?t=github", files=files, headers=headers)
# Assert # Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github" assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
@@ -270,7 +269,7 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
# Act # Act
response = client.get(f"/api/content/types", headers=headers) response = client.get("/api/content/types", headers=headers)
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@@ -286,7 +285,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
client = TestClient(fastapi_app) client = TestClient(fastapi_app)
# Act # Act
response = client.get(f"/api/content/types") response = client.get("/api/content/types")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@@ -454,8 +453,8 @@ def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojAp
headers = {"Authorization": f"Bearer {api_user2.token}"} headers = {"Authorization": f"Bearer {api_user2.token}"}
# Act # Act
auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query}, headers=headers) auth_response = chat_client_with_auth.post("/api/chat", json={"q": query}, headers=headers)
no_auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query}) no_auth_response = chat_client_with_auth.post("/api/chat", json={"q": query})
# Assert # Assert
assert auth_response.status_code == 200 assert auth_response.status_code == 200

View File

@@ -77,12 +77,12 @@ class TestTruncateMessage:
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert ( assert len(chat_history) == 1, (
len(chat_history) == 1 "Only most recent message should be present as it itself is larger than context size"
), "Only most recent message should be present as it itself is larger than context size" )
assert len(truncated_chat_history[0].content) < len( assert len(truncated_chat_history[0].content) < len(copy_big_chat_message.content), (
copy_big_chat_message.content "message content list should be modified"
), "message content list should be modified" )
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
@@ -101,9 +101,9 @@ class TestTruncateMessage:
# Assert # Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert ( assert len(chat_history) == 1, (
len(chat_history) == 1 "Only most recent message should be present as it itself is larger than context size"
), "Only most recent message should be present as it itself is larger than context size" )
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
@@ -150,9 +150,9 @@ class TestTruncateMessage:
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert ( assert len(chat_messages) == 1, (
len(chat_messages) == 1 "Only most recent message should be present as it itself is larger than context size"
), "Only most recent message should be present as it itself is larger than context size" )
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
@@ -172,9 +172,9 @@ class TestTruncateMessage:
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert ( assert len(chat_messages) == 1, (
len(chat_messages) == 1 "Only most recent message should be present as it itself is larger than context size"
), "Only most recent message should be present as it itself is larger than context size" )
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"

View File

@@ -162,15 +162,15 @@ def test_date_extraction():
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected d.m.Y structured date to be extracted" assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected d.m.Y structured date to be extracted"
extracted_dates = DateFilter().extract_dates("CLOCK: [1984-04-01 Sun 09:50]--[1984-04-01 Sun 10:10] => 24:20") extracted_dates = DateFilter().extract_dates("CLOCK: [1984-04-01 Sun 09:50]--[1984-04-01 Sun 10:10] => 24:20")
assert extracted_dates == [ assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
datetime(1984, 4, 1, 0, 0, 0) "Expected single deduplicated date extracted from logbook entry"
], "Expected single deduplicated date extracted from logbook entry" )
extracted_dates = DateFilter().extract_dates("CLOCK: [1984/03/31 mer 09:50]--[1984/04/01 mer 10:10] => 24:20") extracted_dates = DateFilter().extract_dates("CLOCK: [1984/03/31 mer 09:50]--[1984/04/01 mer 10:10] => 24:20")
expected_dates = [datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 3, 31, 0, 0, 0)] expected_dates = [datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 3, 31, 0, 0, 0)]
assert all( assert all([dt in extracted_dates for dt in expected_dates]), (
[dt in extracted_dates for dt in expected_dates] "Expected multiple different dates extracted from logbook entry"
), "Expected multiple different dates extracted from logbook entry" )
def test_natual_date_extraction(): def test_natual_date_extraction():
@@ -187,9 +187,9 @@ def test_natual_date_extraction():
assert datetime(1984, 4, 4, 0, 0, 0) in extracted_dates, "Expected natural date to be extracted" assert datetime(1984, 4, 4, 0, 0, 0) in extracted_dates, "Expected natural date to be extracted"
extracted_dates = DateFilter().extract_dates("head 11th april 1984 tail") extracted_dates = DateFilter().extract_dates("head 11th april 1984 tail")
assert ( assert datetime(1984, 4, 11, 0, 0, 0) in extracted_dates, (
datetime(1984, 4, 11, 0, 0, 0) in extracted_dates "Expected natural date with lowercase month to be extracted"
), "Expected natural date with lowercase month to be extracted" )
extracted_dates = DateFilter().extract_dates("head 23rd april 84 tail") extracted_dates = DateFilter().extract_dates("head 23rd april 84 tail")
assert datetime(1984, 4, 23, 0, 0, 0) in extracted_dates, "Expected natural date with 2-digit year to be extracted" assert datetime(1984, 4, 23, 0, 0, 0) in extracted_dates, "Expected natural date with 2-digit year to be extracted"
@@ -201,16 +201,16 @@ def test_natual_date_extraction():
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected partial natural date to be extracted" assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected partial natural date to be extracted"
extracted_dates = DateFilter().extract_dates("head Apr 1984 tail") extracted_dates = DateFilter().extract_dates("head Apr 1984 tail")
assert extracted_dates == [ assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
datetime(1984, 4, 1, 0, 0, 0) "Expected partial natural date with short month to be extracted"
], "Expected partial natural date with short month to be extracted" )
extracted_dates = DateFilter().extract_dates("head apr 1984 tail") extracted_dates = DateFilter().extract_dates("head apr 1984 tail")
assert extracted_dates == [ assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
datetime(1984, 4, 1, 0, 0, 0) "Expected partial natural date with lowercase month to be extracted"
], "Expected partial natural date with lowercase month to be extracted" )
extracted_dates = DateFilter().extract_dates("head apr 84 tail") extracted_dates = DateFilter().extract_dates("head apr 84 tail")
assert extracted_dates == [ assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
datetime(1984, 4, 1, 0, 0, 0) "Expected partial natural date with 2-digit year to be extracted"
], "Expected partial natural date with 2-digit year to be extracted" )

View File

@@ -1,5 +1,3 @@
import os
from khoj.processor.content.images.image_to_entries import ImageToEntries from khoj.processor.content.images.image_to_entries import ImageToEntries

View File

@@ -8,7 +8,7 @@ from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntrie
def test_extract_markdown_with_no_headings(tmp_path): def test_extract_markdown_with_no_headings(tmp_path):
"Convert markdown file with no heading to entry format." "Convert markdown file with no heading to entry format."
# Arrange # Arrange
entry = f""" entry = """
- Bullet point 1 - Bullet point 1
- Bullet point 2 - Bullet point 2
""" """
@@ -35,7 +35,7 @@ def test_extract_markdown_with_no_headings(tmp_path):
def test_extract_single_markdown_entry(tmp_path): def test_extract_single_markdown_entry(tmp_path):
"Convert markdown from single file to entry format." "Convert markdown from single file to entry format."
# Arrange # Arrange
entry = f"""### Heading entry = """### Heading
\t\r \t\r
Body Line 1 Body Line 1
""" """
@@ -55,7 +55,7 @@ def test_extract_single_markdown_entry(tmp_path):
def test_extract_multiple_markdown_entries(tmp_path): def test_extract_multiple_markdown_entries(tmp_path):
"Convert multiple markdown from single file to entry format." "Convert multiple markdown from single file to entry format."
# Arrange # Arrange
entry = f""" entry = """
### Heading 1 ### Heading 1
\t\r \t\r
Heading 1 Body Line 1 Heading 1 Body Line 1
@@ -81,7 +81,7 @@ def test_extract_multiple_markdown_entries(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path):
"Extract markdown entries with different level headings." "Extract markdown entries with different level headings."
# Arrange # Arrange
entry = f""" entry = """
# Heading 1 # Heading 1
## Sub-Heading 1.1 ## Sub-Heading 1.1
# Heading 2 # Heading 2
@@ -104,7 +104,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
def test_extract_entries_with_non_incremental_heading_levels(tmp_path): def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
"Extract markdown entries when deeper child level before shallower child level." "Extract markdown entries when deeper child level before shallower child level."
# Arrange # Arrange
entry = f""" entry = """
# Heading 1 # Heading 1
#### Sub-Heading 1.1 #### Sub-Heading 1.1
## Sub-Heading 1.2 ## Sub-Heading 1.2
@@ -129,7 +129,7 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
def test_extract_entries_with_text_before_headings(tmp_path): def test_extract_entries_with_text_before_headings(tmp_path):
"Extract markdown entries with some text before any headings." "Extract markdown entries with some text before any headings."
# Arrange # Arrange
entry = f""" entry = """
Text before headings Text before headings
# Heading 1 # Heading 1
body line 1 body line 1
@@ -149,15 +149,15 @@ body line 2
assert len(entries[1]) == 3 assert len(entries[1]) == 3
assert entries[1][0].raw == "\nText before headings" assert entries[1][0].raw == "\nText before headings"
assert entries[1][1].raw == "# Heading 1\nbody line 1" assert entries[1][1].raw == "# Heading 1\nbody line 1"
assert ( assert entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", (
entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n" "Ensure raw entry includes heading ancestory"
), "Ensure raw entry includes heading ancestory" )
def test_parse_markdown_file_into_single_entry_if_small(tmp_path): def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
"Parse markdown file into single entry if it fits within the token limits." "Parse markdown file into single entry if it fits within the token limits."
# Arrange # Arrange
entry = f""" entry = """
# Heading 1 # Heading 1
body line 1 body line 1
## Subheading 1.1 ## Subheading 1.1
@@ -180,7 +180,7 @@ body line 1.1
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path): def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
"Parse markdown entry with child headings as single entry if it fits within the tokens limits." "Parse markdown entry with child headings as single entry if it fits within the tokens limits."
# Arrange # Arrange
entry = f""" entry = """
# Heading 1 # Heading 1
body line 1 body line 1
## Subheading 1.1 ## Subheading 1.1
@@ -201,13 +201,13 @@ longer body line 2.1
# Assert # Assert
assert len(entries) == 2 assert len(entries) == 2
assert len(entries[1]) == 3 assert len(entries[1]) == 3
assert ( assert entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1", (
entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1" "First entry includes children headings"
), "First entry includes children headings" )
assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings" assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert ( assert entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n", (
entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n" "Third entry is second entries child heading"
), "Third entry is second entries child heading" )
def test_line_number_tracking_in_recursive_split(): def test_line_number_tracking_in_recursive_split():
@@ -252,14 +252,16 @@ def test_line_number_tracking_in_recursive_split():
assert entry.uri is not None, f"Entry '{entry}' has a None URI." assert entry.uri is not None, f"Entry '{entry}' has a None URI."
assert match is not None, f"URI format is incorrect: {entry.uri}" assert match is not None, f"URI format is incorrect: {entry.uri}"
assert ( assert filepath_from_uri == markdown_file_path, (
filepath_from_uri == markdown_file_path f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'"
), f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'" )
# Ensure the first non-heading line in the compiled entry matches the line in the file # Ensure the first non-heading line in the compiled entry matches the line in the file
assert ( assert (
cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip() cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip()
), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" ), (
f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
)
# Helper Functions # Helper Functions

View File

@@ -343,12 +343,12 @@ Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.org", "file": "Ledger.org",
}, },
{ {
"compiled": f"""2020-04-01 "SuperMercado" "Bananas" "compiled": """2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""", Expenses:Food:Groceries 10.00 USD""",
"file": "Ledger.org", "file": "Ledger.org",
}, },
{ {
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" "compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""", Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.org", "file": "Ledger.org",
}, },
@@ -389,12 +389,12 @@ Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.md", "file": "Ledger.md",
}, },
{ {
"compiled": f"""2020-04-01 "SuperMercado" "Bananas" "compiled": """2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""", Expenses:Food:Groceries 10.00 USD""",
"file": "Ledger.md", "file": "Ledger.md",
}, },
{ {
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" "compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""", Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.md", "file": "Ledger.md",
}, },
@@ -452,17 +452,17 @@ async def test_ask_for_clarification_if_not_enough_context_in_question():
# Arrange # Arrange
context = [ context = [
{ {
"compiled": f"""# Ramya "compiled": """# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""", My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
"file": "Family.md", "file": "Family.md",
}, },
{ {
"compiled": f"""# Fang "compiled": """# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""", My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
"file": "Family.md", "file": "Family.md",
}, },
{ {
"compiled": f"""# Aiyla "compiled": """# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""", My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
"file": "Family.md", "file": "Family.md",
}, },
@@ -497,9 +497,9 @@ async def test_agent_prompt_should_be_used(openai_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition" "Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange # Arrange
context = [ context = [
{"compiled": f"""I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"}, {"compiled": """I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"},
{"compiled": f"""I went to the store and bought some apples for 1.30""", "file": "Ledger.md"}, {"compiled": """I went to the store and bought some apples for 1.30""", "file": "Ledger.md"},
{"compiled": f"""I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"}, {"compiled": """I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"},
] ]
expected_responses = ["9.50", "9.5"] expected_responses = ["9.50", "9.5"]
@@ -539,13 +539,13 @@ async def test_websearch_with_operators(chat_client, default_user2):
responses = await generate_online_subqueries(user_query, [], None, default_user2) responses = await generate_online_subqueries(user_query, [], None, default_user2)
# Assert # Assert
assert any( assert any(["reddit.com/r/worldnews" in response for response in responses]), (
["reddit.com/r/worldnews" in response for response in responses] "Expected a search query to include site:reddit.com but got: " + str(responses)
), "Expected a search query to include site:reddit.com but got: " + str(responses) )
assert any( assert any(["site:reddit.com" in response for response in responses]), (
["site:reddit.com" in response for response in responses] "Expected a search query to include site:reddit.com but got: " + str(responses)
), "Expected a search query to include site:reddit.com but got: " + str(responses) )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -559,9 +559,9 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
responses = await generate_online_subqueries(user_query, [], None, default_user2) responses = await generate_online_subqueries(user_query, [], None, default_user2)
# Assert # Assert
assert any( assert any(["site:khoj.dev" in response for response in responses]), (
["site:khoj.dev" in response for response in responses] "Expected search query to include site:khoj.dev but got: " + str(responses)
), "Expected search query to include site:khoj.dev but got: " + str(responses) )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -693,9 +693,9 @@ def test_infer_task_scheduling_request(
for expected_q in expected_qs: for expected_q in expected_qs:
assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}" assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}"
for unexpected_q in unexpected_qs: for unexpected_q in unexpected_qs:
assert ( assert unexpected_q not in inferred_query, (
unexpected_q not in inferred_query f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'"
), f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'" )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------

View File

@@ -33,7 +33,7 @@ def create_conversation(message_list, user, agent=None):
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client): def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"}) response = chat_client.post("/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -50,7 +50,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
def test_chat_with_online_content(chat_client): def test_chat_with_online_content(chat_client):
# Act # Act
q = "/online give me the link to paul graham's essay how to do great work" q = "/online give me the link to paul graham's essay how to do great work"
response = chat_client.post(f"/api/chat?", json={"q": q}) response = chat_client.post("/api/chat?", json={"q": q})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -59,9 +59,9 @@ def test_chat_with_online_content(chat_client):
"paulgraham.com/hwh.html", "paulgraham.com/hwh.html",
] ]
assert response.status_code == 200 assert response.status_code == 200
assert any( assert any([expected_response in response_message for expected_response in expected_responses]), (
[expected_response in response_message for expected_response in expected_responses] f"Expected links: {expected_responses}. Actual response: {response_message}"
), f"Expected links: {expected_responses}. Actual response: {response_message}" )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -70,15 +70,15 @@ def test_chat_with_online_content(chat_client):
def test_chat_with_online_webpage_content(chat_client): def test_chat_with_online_webpage_content(chat_client):
# Act # Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
response = chat_client.post(f"/api/chat", json={"q": q}) response = chat_client.post("/api/chat", json={"q": q})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200 assert response.status_code == 200
assert any( assert any([expected_response in response_message for expected_response in expected_responses]), (
[expected_response in response_message for expected_response in expected_responses] f"Expected links: {expected_responses}. Actual response: {response_message}"
), f"Expected links: {expected_responses}. Actual response: {response_message}" )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -93,7 +93,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "What is my name?"}) response = chat_client.post("/api/chat", json={"q": "What is my name?"})
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
# Assert # Assert
@@ -120,7 +120,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "Where was Xi Li born?"}) response = chat_client.post("/api/chat", json={"q": "Where was Xi Li born?"})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -144,7 +144,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client_no_background.post(f"/api/chat", json={"q": "Where was I born?"}) response = chat_client_no_background.post("/api/chat", json={"q": "Where was I born?"})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -167,7 +167,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"}) response = chat_client.post("/api/chat", json={"q": "Where was I born?"})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -192,7 +192,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"}) response = chat_client.post("/api/chat", json={"q": "Where was I born?"})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -222,7 +222,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser):
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": query, "stream": True}) response = chat_client.post("/api/chat", json={"q": query, "stream": True})
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
# Assert # Assert
@@ -240,7 +240,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -258,7 +258,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client_no_background.post(f"/api/chat", json={"q": query}) response = chat_client_no_background.post("/api/chat", json={"q": query})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -291,7 +291,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)}, json={"filename": summarization_file, "conversation_id": str(conversation.id)},
) )
query = "/summarize" query = "/summarize"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
assert response_message != "" assert response_message != ""
@@ -322,7 +322,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)}, json={"filename": summarization_file, "conversation_id": str(conversation.id)},
) )
query = "/summarize tell me about Xiu" query = "/summarize tell me about Xiu"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
assert response_message != "" assert response_message != ""
@@ -349,7 +349,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
) )
query = "/summarize" query = "/summarize"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -365,7 +365,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
# Act # Act
query = "/summarize" query = "/summarize"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -400,11 +400,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
# Act # Act
query = "/summarize" query = "/summarize"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
response_message_conv2 = response.json()["response"] response_message_conv2 = response.json()["response"]
# now make sure that the file filter is still in conversation 1 # now make sure that the file filter is still in conversation 1
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
response_message_conv1 = response.json()["response"] response_message_conv1 = response.json()["response"]
# Assert # Assert
@@ -430,7 +430,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)}, json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left." assert response_message == "No files selected for summarization. Please add files using the section on the left."
@@ -462,7 +462,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
# Act # Act
query = "/summarize" query = "/summarize"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -477,7 +477,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
def test_answer_requires_current_date_awareness(chat_client): def test_answer_requires_current_date_awareness(chat_client):
"Chat actor should be able to answer questions relative to current date using provided notes" "Chat actor should be able to answer questions relative to current date using provided notes"
# Act # Act
response = chat_client.post(f"/api/chat", json={"q": "Where did I have lunch today?", "stream": True}) response = chat_client.post("/api/chat", json={"q": "Where did I have lunch today?", "stream": True})
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
# Assert # Assert
@@ -496,7 +496,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
"Chat director should be able to answer questions that require date aware aggregation across multiple notes" "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act # Act
query = "How much did I spend on dining this year?" query = "How much did I spend on dining this year?"
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -518,7 +518,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# Act # Act
query = "Write a haiku about unit testing. Do not say anything else." query = "Write a haiku about unit testing. Do not say anything else."
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -536,7 +536,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act # Act
query = "What is the name of Namitas older son?" query = "What is the name of Namitas older son?"
response = chat_client_no_background.post(f"/api/chat", json={"q": query}) response = chat_client_no_background.post("/api/chat", json={"q": query})
response_message = response.json()["response"].lower() response_message = response.json()["response"].lower()
# Assert # Assert
@@ -571,7 +571,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
# Act # Act
query = "What is my name?" query = "What is my name?"
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert # Assert
@@ -604,9 +604,7 @@ def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: K
# Act # Act
query = "/general What is my favorite color?" query = "/general What is my favorite color?"
response = chat_client.post( response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True})
f"/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True}
)
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
# Assert # Assert
@@ -639,7 +637,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
# Act # Act
query = "/general What did I buy for breakfast?" query = "/general What did I buy for breakfast?"
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
response_message = response.json()["response"] response_message = response.json()["response"]
# Assert that agent only responds with the summary of spending # Assert that agent only responds with the summary of spending
@@ -657,7 +655,7 @@ def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information" "Chat director should be able to answer by doing multiple independent searches for required information"
# Act # Act
query = "Is Xi Li older than Namita? Just say the older persons full name" query = "Is Xi Li older than Namita? Just say the older persons full name"
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"].lower() response_message = response.json()["response"].lower()
# Assert # Assert
@@ -681,7 +679,7 @@ def test_answer_using_file_filter(chat_client):
query = ( query = (
'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' 'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
) )
response = chat_client.post(f"/api/chat", json={"q": query}) response = chat_client.post("/api/chat", json={"q": query})
response_message = response.json()["response"].lower() response_message = response.json()["response"].lower()
# Assert # Assert

View File

@@ -12,7 +12,7 @@ def test_configure_indexing_heading_only_entries(tmp_path):
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. """Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.""" Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
# Arrange # Arrange
entry = f"""*** Heading entry = """*** Heading
:PROPERTIES: :PROPERTIES:
:ID: 42-42-42 :ID: 42-42-42
:END: :END:
@@ -74,7 +74,7 @@ def test_entry_split_when_exceeds_max_tokens():
"Ensure entries with compiled words exceeding max_tokens are split." "Ensure entries with compiled words exceeding max_tokens are split."
# Arrange # Arrange
tmp_path = "/tmp/test.org" tmp_path = "/tmp/test.org"
entry = f"""*** Heading entry = """*** Heading
\t\r \t\r
Body Line Body Line
""" """
@@ -99,7 +99,7 @@ def test_entry_split_when_exceeds_max_tokens():
def test_entry_split_drops_large_words(): def test_entry_split_drops_large_words():
"Ensure entries drops words larger than specified max word length from compiled version." "Ensure entries drops words larger than specified max word length from compiled version."
# Arrange # Arrange
entry_text = f"""First Line entry_text = """First Line
dog=1\n\r\t dog=1\n\r\t
cat=10 cat=10
car=4 car=4
@@ -124,7 +124,7 @@ book=2
def test_parse_org_file_into_single_entry_if_small(tmp_path): def test_parse_org_file_into_single_entry_if_small(tmp_path):
"Parse org file into single entry if it fits within the token limits." "Parse org file into single entry if it fits within the token limits."
# Arrange # Arrange
original_entry = f""" original_entry = """
* Heading 1 * Heading 1
body line 1 body line 1
** Subheading 1.1 ** Subheading 1.1
@@ -133,7 +133,7 @@ body line 1.1
data = { data = {
f"{tmp_path}": original_entry, f"{tmp_path}": original_entry,
} }
expected_entry = f""" expected_entry = """
* Heading 1 * Heading 1
body line 1 body line 1
@@ -155,7 +155,7 @@ body line 1.1
def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path): def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
"Parse org entry with child headings as single entry only if it fits within the tokens limits." "Parse org entry with child headings as single entry only if it fits within the tokens limits."
# Arrange # Arrange
entry = f""" entry = """
* Heading 1 * Heading 1
body line 1 body line 1
** Subheading 1.1 ** Subheading 1.1
@@ -205,7 +205,7 @@ longer body line 2.1
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path): def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
"Parse org sibling entries as separate entries only if it fits within the tokens limits." "Parse org sibling entries as separate entries only if it fits within the tokens limits."
# Arrange # Arrange
entry = f""" entry = """
* Heading 1 * Heading 1
body line 1 body line 1
** Subheading 1.1 ** Subheading 1.1
@@ -267,7 +267,7 @@ body line 3.1
def test_entry_with_body_to_entry(tmp_path): def test_entry_with_body_to_entry(tmp_path):
"Ensure entries with valid body text are loaded." "Ensure entries with valid body text are loaded."
# Arrange # Arrange
entry = f"""*** Heading entry = """*** Heading
:PROPERTIES: :PROPERTIES:
:ID: 42-42-42 :ID: 42-42-42
:END: :END:
@@ -290,7 +290,7 @@ def test_entry_with_body_to_entry(tmp_path):
def test_file_with_entry_after_intro_text_to_entry(tmp_path): def test_file_with_entry_after_intro_text_to_entry(tmp_path):
"Ensure intro text before any headings is indexed." "Ensure intro text before any headings is indexed."
# Arrange # Arrange
entry = f""" entry = """
Intro text Intro text
* Entry Heading * Entry Heading
@@ -312,7 +312,7 @@ Intro text
def test_file_with_no_headings_to_entry(tmp_path): def test_file_with_no_headings_to_entry(tmp_path):
"Ensure files with no heading, only body text are loaded." "Ensure files with no heading, only body text are loaded."
# Arrange # Arrange
entry = f""" entry = """
- Bullet point 1 - Bullet point 1
- Bullet point 2 - Bullet point 2
""" """
@@ -332,7 +332,7 @@ def test_file_with_no_headings_to_entry(tmp_path):
def test_extract_entries_with_different_level_headings(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path):
"Extract org entries with different level headings." "Extract org entries with different level headings."
# Arrange # Arrange
entry = f""" entry = """
* Heading 1 * Heading 1
** Sub-Heading 1.1 ** Sub-Heading 1.1
* Heading 2 * Heading 2
@@ -396,14 +396,16 @@ def test_line_number_tracking_in_recursive_split():
assert entry.uri is not None, f"Entry '{entry}' has a None URI." assert entry.uri is not None, f"Entry '{entry}' has a None URI."
assert match is not None, f"URI format is incorrect: {entry.uri}" assert match is not None, f"URI format is incorrect: {entry.uri}"
assert ( assert filepath_from_uri == org_file_path, (
filepath_from_uri == org_file_path f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'"
), f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'" )
# Ensure the first non-heading line in the compiled entry matches the line in the file # Ensure the first non-heading line in the compiled entry matches the line in the file
assert ( assert (
cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip() cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip()
), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" ), (
f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
)
# Helper Functions # Helper Functions

View File

@@ -8,7 +8,7 @@ from khoj.processor.content.org_mode import orgnode
def test_parse_entry_with_no_headings(tmp_path): def test_parse_entry_with_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f"""Body Line 1""" entry = """Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
# Act # Act
@@ -30,7 +30,7 @@ def test_parse_entry_with_no_headings(tmp_path):
def test_parse_minimal_entry(tmp_path): def test_parse_minimal_entry(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f""" entry = """
* Heading * Heading
Body Line 1""" Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
@@ -54,7 +54,7 @@ Body Line 1"""
def test_parse_complete_entry(tmp_path): def test_parse_complete_entry(tmp_path):
"Test parsing of entry with all important fields" "Test parsing of entry with all important fields"
# Arrange # Arrange
entry = f""" entry = """
*** DONE [#A] Heading :Tag1:TAG2:tag3: *** DONE [#A] Heading :Tag1:TAG2:tag3:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES: :PROPERTIES:
@@ -89,7 +89,7 @@ Body Line 2"""
def test_render_entry_with_property_drawer_and_empty_body(tmp_path): def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
"Render heading entry with property drawer" "Render heading entry with property drawer"
# Arrange # Arrange
entry_to_render = f""" entry_to_render = """
*** [#A] Heading1 :tag1: *** [#A] Heading1 :tag1:
:PROPERTIES: :PROPERTIES:
:ID: 111-111-111-1111-1111 :ID: 111-111-111-1111-1111
@@ -116,7 +116,7 @@ def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
def test_all_links_to_entry_rendered(tmp_path): def test_all_links_to_entry_rendered(tmp_path):
"Ensure all links to entry rendered in property drawer from entry" "Ensure all links to entry rendered in property drawer from entry"
# Arrange # Arrange
entry = f""" entry = """
*** [#A] Heading :tag1: *** [#A] Heading :tag1:
:PROPERTIES: :PROPERTIES:
:ID: 123-456-789-4234-1231 :ID: 123-456-789-4234-1231
@@ -133,7 +133,7 @@ Body Line 2
# Assert # Assert
# SOURCE link rendered with Heading # SOURCE link rendered with Heading
# ID link rendered with ID # ID link rendered with ID
assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}" assert ":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
# LINE link rendered with line number # LINE link rendered with line number
assert f":LINE: file://{orgfile}#line=2" in f"{entries[0]}" assert f":LINE: file://{orgfile}#line=2" in f"{entries[0]}"
# LINE link rendered with line number # LINE link rendered with line number
@@ -144,7 +144,7 @@ Body Line 2
def test_parse_multiple_entries(tmp_path): def test_parse_multiple_entries(tmp_path):
"Test parsing of multiple entries" "Test parsing of multiple entries"
# Arrange # Arrange
content = f""" content = """
*** FAILED [#A] Heading1 :tag1: *** FAILED [#A] Heading1 :tag1:
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
:PROPERTIES: :PROPERTIES:
@@ -176,12 +176,12 @@ Body 2
# Assert # Assert
assert len(entries) == 2 assert len(entries) == 2
for index, entry in enumerate(entries): for index, entry in enumerate(entries):
assert entry.heading == f"Heading{index+1}" assert entry.heading == f"Heading{index + 1}"
assert entry.todo == "FAILED" if index == 0 else "CANCELLED" assert entry.todo == "FAILED" if index == 0 else "CANCELLED"
assert entry.tags == [f"tag{index+1}"] assert entry.tags == [f"tag{index + 1}"]
assert entry.body == f"- Clocked Log {index+1}\n\nBody {index+1}\n\n" assert entry.body == f"- Clocked Log {index + 1}\n\nBody {index + 1}\n\n"
assert entry.priority == "A" assert entry.priority == "A"
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}" assert entry.Property("ID") == f"id:123-456-789-4234-000{index + 1}"
assert entry.closed == datetime.date(1984, 4, index + 1) assert entry.closed == datetime.date(1984, 4, index + 1)
assert entry.scheduled == datetime.date(1984, 4, index + 1) assert entry.scheduled == datetime.date(1984, 4, index + 1)
assert entry.deadline == datetime.date(1984, 4, index + 1) assert entry.deadline == datetime.date(1984, 4, index + 1)
@@ -194,7 +194,7 @@ Body 2
def test_parse_entry_with_empty_title(tmp_path): def test_parse_entry_with_empty_title(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f"""#+TITLE: entry = """#+TITLE:
Body Line 1""" Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
@@ -217,7 +217,7 @@ Body Line 1"""
def test_parse_entry_with_title_and_no_headings(tmp_path): def test_parse_entry_with_title_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f"""#+TITLE: test entry = """#+TITLE: test
Body Line 1""" Body Line 1"""
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
@@ -241,7 +241,7 @@ Body Line 1"""
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path): def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
"Test parsing of entry with minimal fields" "Test parsing of entry with minimal fields"
# Arrange # Arrange
entry = f"""#+TITLE: title1 entry = """#+TITLE: title1
Body Line 1 Body Line 1
#+TITLE: title2 """ #+TITLE: title2 """
orgfile = create_file(tmp_path, entry) orgfile = create_file(tmp_path, entry)
@@ -266,7 +266,7 @@ Body Line 1
def test_parse_org_with_intro_text_before_heading(tmp_path): def test_parse_org_with_intro_text_before_heading(tmp_path):
"Test parsing of org file with intro text before heading" "Test parsing of org file with intro text before heading"
# Arrange # Arrange
body = f"""#+TITLE: Title body = """#+TITLE: Title
intro body intro body
* Entry Heading * Entry Heading
entry body entry body
@@ -290,7 +290,7 @@ entry body
def test_parse_org_with_intro_text_multiple_titles_and_heading(tmp_path): def test_parse_org_with_intro_text_multiple_titles_and_heading(tmp_path):
"Test parsing of org file with intro text, multiple titles and heading entry" "Test parsing of org file with intro text, multiple titles and heading entry"
# Arrange # Arrange
body = f"""#+TITLE: Title1 body = """#+TITLE: Title1
intro body intro body
* Entry Heading * Entry Heading
entry body entry body
@@ -314,7 +314,7 @@ entry body
def test_parse_org_with_single_ancestor_heading(tmp_path): def test_parse_org_with_single_ancestor_heading(tmp_path):
"Parse org entries with parent headings context" "Parse org entries with parent headings context"
# Arrange # Arrange
body = f""" body = """
* Heading 1 * Heading 1
body 1 body 1
** Sub Heading 1 ** Sub Heading 1
@@ -336,7 +336,7 @@ body 1
def test_parse_org_with_multiple_ancestor_headings(tmp_path): def test_parse_org_with_multiple_ancestor_headings(tmp_path):
"Parse org entries with parent headings context" "Parse org entries with parent headings context"
# Arrange # Arrange
body = f""" body = """
* Heading 1 * Heading 1
body 1 body 1
** Sub Heading 1 ** Sub Heading 1
@@ -362,7 +362,7 @@ sub sub body 1
def test_parse_org_with_multiple_ancestor_headings_of_siblings(tmp_path): def test_parse_org_with_multiple_ancestor_headings_of_siblings(tmp_path):
"Parse org entries with parent headings context" "Parse org entries with parent headings context"
# Arrange # Arrange
body = f""" body = """
* Heading 1 * Heading 1
body 1 body 1
** Sub Heading 1 ** Sub Heading 1

View File

@@ -7,7 +7,7 @@ from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEnt
def test_plaintext_file(): def test_plaintext_file():
"Convert files with no heading to jsonl." "Convert files with no heading to jsonl."
# Arrange # Arrange
raw_entry = f""" raw_entry = """
Hi, I am a plaintext file and I have some plaintext words. Hi, I am a plaintext file and I have some plaintext words.
""" """
plaintextfile = "test.txt" plaintextfile = "test.txt"

View File

@@ -145,9 +145,9 @@ def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: Kho
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
# Assert # Assert
assert ( assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, (
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message "new entry not split by max tokens"
), "new entry not split by max tokens" )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -198,9 +198,9 @@ conda activate khoj
) )
# Assert # Assert
assert ( assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, (
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message "new entry not split by max tokens"
), "new entry not split by max tokens" )
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------