diff --git a/src/khoj/app/urls.py b/src/khoj/app/urls.py index 2754270c..d5c2c1f5 100644 --- a/src/khoj/app/urls.py +++ b/src/khoj/app/urls.py @@ -14,6 +14,7 @@ Including another URLconf 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.contrib.staticfiles.urls import staticfiles_urlpatterns from django.urls import path diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 6d92b8e9..e1711025 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1910,9 +1910,9 @@ class EntryAdapters: owner_filter = Q() - if user != None: + if user is not None: owner_filter = Q(user=user) - if agent != None: + if agent is not None: owner_filter |= Q(agent=agent) if owner_filter == Q(): @@ -1972,9 +1972,9 @@ class EntryAdapters: ): owner_filter = Q() - if user != None: + if user is not None: owner_filter = Q(user=user) - if agent != None: + if agent is not None: owner_filter |= Q(agent=agent) if owner_filter == Q(): diff --git a/src/khoj/database/management/commands/delete_orphaned_fileobjects.py b/src/khoj/database/management/commands/delete_orphaned_fileobjects.py index f7efa1dd..99d45c6f 100644 --- a/src/khoj/database/management/commands/delete_orphaned_fileobjects.py +++ b/src/khoj/database/management/commands/delete_orphaned_fileobjects.py @@ -1,5 +1,4 @@ from django.core.management.base import BaseCommand -from django.db import transaction from django.db.models import Exists, OuterRef from khoj.database.models import Entry, FileObject diff --git a/src/khoj/database/migrations/0064_remove_conversation_temp_id_alter_conversation_id.py b/src/khoj/database/migrations/0064_remove_conversation_temp_id_alter_conversation_id.py index 16d76d10..acc5dc4c 100644 --- a/src/khoj/database/migrations/0064_remove_conversation_temp_id_alter_conversation_id.py +++ b/src/khoj/database/migrations/0064_remove_conversation_temp_id_alter_conversation_id.py @@ -41,7 +41,7 @@ def update_conversation_id_in_job_state(apps, schema_editor): job.save() except Conversation.DoesNotExist: pass - except LookupError as e: + except LookupError: pass diff --git a/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py b/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py index 40c74ebf..3d3c4a06 100644 --- a/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py +++ b/src/khoj/database/migrations/0075_migrate_generated_assets_and_validate.py @@ -1,6 +1,6 @@ # Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59 -from django.db import migrations, models +from django.db import migrations # This script was written alongside when Pydantic validation was added to the Conversation conversation_log field. diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 1ed58572..e0874f59 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -551,12 +551,12 @@ class TextToImageModelConfig(DbBaseModel): error = {} if self.model_type == self.ModelType.OPENAI: if self.api_key and self.ai_model_api: - error[ - "api_key" - ] = "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them." - error[ - "ai_model_api" - ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + error["api_key"] = ( + "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them." + ) + error["ai_model_api"] = ( + "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + ) if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE: if not self.api_key: error["api_key"] = "The API key field must be set for non OpenAI, non Google models." diff --git a/src/khoj/database/tests.py b/src/khoj/database/tests.py index 7ce503c2..a39b155a 100644 --- a/src/khoj/database/tests.py +++ b/src/khoj/database/tests.py @@ -1,3 +1 @@ -from django.test import TestCase - # Create your tests here. diff --git a/src/khoj/main.py b/src/khoj/main.py index f42ae135..d9e96afd 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -1,5 +1,5 @@ -""" Main module for Khoj - isort:skip_file +"""Main module for Khoj +isort:skip_file """ from contextlib import redirect_stdout @@ -189,7 +189,7 @@ def run(should_start_server=True): static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static") if not os.path.exists(static_dir): os.mkdir(static_dir) - app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir) + app.mount("/static", StaticFiles(directory=static_dir), name=static_dir) # Configure Middleware configure_middleware(app, state.ssl_config) diff --git a/src/khoj/manage.py b/src/khoj/manage.py index 9b8f4b27..aaf9db17 100755 --- a/src/khoj/manage.py +++ b/src/khoj/manage.py @@ -1,5 +1,6 @@ #!/usr/bin/env python """Django's command-line utility for administrative tasks.""" + import os import sys diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 63ed50c6..51f8b610 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -51,7 +51,7 @@ class GithubToEntries(TextToEntries): def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: if is_none_or_empty(self.config.pat_token): logger.warning( - f"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply." + "Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply." ) current_entries = [] for repo in self.config.repos: @@ -137,7 +137,7 @@ class GithubToEntries(TextToEntries): # Find all markdown files in the repository if item["type"] == "blob" and item["path"].endswith(".md"): # Create URL for each markdown file on Github - url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' + url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}" # Add markdown file contents and URL to list markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}] @@ -145,19 +145,19 @@ class GithubToEntries(TextToEntries): # Find all org files in the repository elif item["type"] == "blob" and item["path"].endswith(".org"): # Create URL for each org file on Github - url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' + url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}" # Add org file contents and URL to list org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}] # Find, index remaining non-binary files in the repository elif item["type"] == "blob": - url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}' + url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}" content_bytes = self.get_file_contents(item["url"], decode=False) content_type, content_str = None, None try: content_type = magika.identify_bytes(content_bytes).output.group - except: + except Exception: logger.error(f"Unable to identify content type of file at {url_path}. Skip indexing it") continue @@ -165,7 +165,7 @@ class GithubToEntries(TextToEntries): if content_type in ["text", "code"]: try: content_str = content_bytes.decode("utf-8") - except: + except Exception: logger.error(f"Unable to decode content of file at {url_path}. Skip indexing it") continue plaintext_files += [{"content": content_str, "path": url_path}] diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py index 8be2bd0f..ea896e59 100644 --- a/src/khoj/processor/content/images/image_to_entries.py +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -1,4 +1,3 @@ -import base64 import logging import os from datetime import datetime, timezone diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index 43b10431..6be92141 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -1,6 +1,5 @@ import logging import re -from pathlib import Path from typing import Dict, List, Tuple import urllib3.util @@ -86,7 +85,7 @@ class MarkdownToEntries(TextToEntries): # If content is small or content has no children headings, save it as a single entry if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search( - rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE + rf"^#{{{len(ancestry) + 1},}}\s", markdown_content, flags=re.MULTILINE ): # Create entry with line number information entry_with_line_info = (markdown_content_with_ancestry, markdown_file, start_line) @@ -160,7 +159,7 @@ class MarkdownToEntries(TextToEntries): calculated_line = start_line if start_line > 0 else 1 # Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path. - if type(raw_filename) == str and re.search(r"^https?://", raw_filename): + if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename): # Escape the URL to avoid issues with special characters entry_filename = urllib3.util.parse_url(raw_filename).url uri = entry_filename diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 23b96f63..72152d4e 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -91,7 +91,7 @@ class NotionToEntries(TextToEntries): json=self.body_params, ).json() responses.append(result) - if result.get("has_more", False) == False: + if not result.get("has_more", False): break else: self.body_params.update({"start_cursor": result["next_cursor"]}) @@ -118,7 +118,7 @@ class NotionToEntries(TextToEntries): page_id = page["id"] title, content = self.get_page_content(page_id) - if title == None or content == None: + if title is None or content is None: return [] current_entries = [] @@ -126,11 +126,11 @@ class NotionToEntries(TextToEntries): for block in content.get("results", []): block_type = block.get("type") - if block_type == None: + if block_type is None: continue block_data = block[block_type] - if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0: + if block_data.get("rich_text") is None or len(block_data["rich_text"]) == 0: # There's no text to handle here. continue @@ -179,7 +179,7 @@ class NotionToEntries(TextToEntries): results = children.get("results", []) for child in results: child_type = child.get("type") - if child_type == None: + if child_type is None: continue child_data = child[child_type] if child_data.get("rich_text") and len(child_data["rich_text"]) > 0: diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index 0dfe7674..6ac573c0 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -8,7 +8,6 @@ from khoj.database.models import KhojUser from khoj.processor.content.org_mode import orgnode from khoj.processor.content.org_mode.orgnode import Orgnode from khoj.processor.content.text_to_entries import TextToEntries -from khoj.utils import state from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry @@ -103,7 +102,7 @@ class OrgToEntries(TextToEntries): # If content is small or content has no children headings, save it as a single entry # Note: This is the terminating condition for this recursive function if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search( - rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE + rf"^\*{{{len(ancestry) + 1},}}\s", org_content, re.MULTILINE ): orgnode_content_with_ancestry = orgnode.makelist( org_content_with_ancestry, org_file, start_line=start_line, ancestry_lines=len(ancestry) @@ -195,7 +194,7 @@ class OrgToEntries(TextToEntries): if not entry_heading and parsed_entry.level > 0: base_level = parsed_entry.level # Indent entry by 1 heading level as ancestry is prepended as top level heading - heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else "" + heading = f"{'*' * (parsed_entry.level - base_level + 2)} {todo_str}" if parsed_entry.level > 0 else "" if parsed_entry.heading: heading += f"{parsed_entry.heading}." @@ -212,10 +211,10 @@ class OrgToEntries(TextToEntries): compiled += f"\t {tags_str}." if parsed_entry.closed: - compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' + compiled += f"\n Closed on {parsed_entry.closed.strftime('%Y-%m-%d')}." if parsed_entry.scheduled: - compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' + compiled += f"\n Scheduled for {parsed_entry.scheduled.strftime('%Y-%m-%d')}." if parsed_entry.hasBody: compiled += f"\n {parsed_entry.body}" diff --git a/src/khoj/processor/content/org_mode/orgnode.py b/src/khoj/processor/content/org_mode/orgnode.py index 34bb54f3..bf014430 100644 --- a/src/khoj/processor/content/org_mode/orgnode.py +++ b/src/khoj/processor/content/org_mode/orgnode.py @@ -65,7 +65,7 @@ def makelist(file, filename, start_line: int = 1, ancestry_lines: int = 0) -> Li """ ctr = 0 - if type(file) == str: + if isinstance(file, str): f = file.splitlines() else: f = file @@ -512,11 +512,11 @@ class Orgnode(object): if self._closed or self._scheduled or self._deadline: n = n + indent if self._closed: - n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] ' + n = n + f"CLOSED: [{self._closed.strftime('%Y-%m-%d %a')}] " if self._scheduled: - n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> ' + n = n + f"SCHEDULED: <{self._scheduled.strftime('%Y-%m-%d %a')}> " if self._deadline: - n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> ' + n = n + f"DEADLINE: <{self._deadline.strftime('%Y-%m-%d %a')}> " if self._closed or self._scheduled or self._deadline: n = n + "\n" diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 64470c08..80d191bb 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -1,6 +1,5 @@ import logging import re -from pathlib import Path from typing import Dict, List, Tuple import urllib3 @@ -97,7 +96,7 @@ class PlaintextToEntries(TextToEntries): for parsed_entry in parsed_entries: raw_filename = entry_to_file_map[parsed_entry] # Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path. - if type(raw_filename) == str and re.search(r"^https?://", raw_filename): + if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename): # Escape the URL to avoid issues with special characters entry_filename = urllib3.util.parse_url(raw_filename).url else: diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 0369d273..959726ea 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -30,8 +30,7 @@ class TextToEntries(ABC): self.date_filter = DateFilter() @abstractmethod - def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: - ... + def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ... @staticmethod def hash_func(key: str) -> Callable: diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index f78c420b..92f8cce7 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -194,7 +194,7 @@ def gemini_completion_with_backoff( or not response.candidates[0].content or response.candidates[0].content.parts is None ): - raise ValueError(f"Failed to get response from model.") + raise ValueError("Failed to get response from model.") raw_content = [part.model_dump() for part in response.candidates[0].content.parts] if response.function_calls: function_calls = [ @@ -212,7 +212,7 @@ def gemini_completion_with_backoff( response = None # Handle 429 rate limit errors directly if e.code == 429: - response_text = f"My brain is exhausted. Can you please try again in a bit?" + response_text = "My brain is exhausted. Can you please try again in a bit?" # Log the full error details for debugging logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}") # Handle other errors @@ -361,7 +361,7 @@ def handle_gemini_response( # Ensure we have a proper list of candidates if not isinstance(candidates, list): - message = f"\nUnexpected response format. Try again." + message = "\nUnexpected response format. Try again." stopped = True return message, stopped diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 523786fc..55e47c83 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -2,7 +2,6 @@ import json import logging import os from copy import deepcopy -from functools import partial from time import perf_counter from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union from urllib.parse import urlparse @@ -284,9 +283,9 @@ async def chat_completion_with_backoff( if len(system_messages) > 0: first_system_message_index, first_system_message = system_messages[0] first_system_message_content = first_system_message["content"] - formatted_messages[first_system_message_index][ - "content" - ] = f"{first_system_message_content}\nFormatting re-enabled" + formatted_messages[first_system_message_index]["content"] = ( + f"{first_system_message_content}\nFormatting re-enabled" + ) elif is_twitter_reasoning_model(model_name, api_base_url): reasoning_effort = "high" if deepthought else "low" # Grok-4 models do not support reasoning_effort parameter diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 6e60c786..1b94ba28 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,7 +1,6 @@ import base64 import json import logging -import math import mimetypes import os import re @@ -18,7 +17,7 @@ import requests import tiktoken import yaml from langchain_core.messages.chat import ChatMessage -from pydantic import BaseModel, ConfigDict, ValidationError, create_model +from pydantic import BaseModel, ConfigDict, ValidationError from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters @@ -47,7 +46,11 @@ from khoj.utils.yaml import yaml_dump logger = logging.getLogger(__name__) try: - from git import Repo + import importlib.util + + git_spec = importlib.util.find_spec("git") + if git_spec is None: + raise ImportError except ImportError: if is_promptrace_enabled(): logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.") @@ -294,7 +297,7 @@ def construct_chat_history_for_operator(conversation_history: List[ChatMessageMo if chat.by == "you" and chat.message: content = [{"type": "text", "text": chat.message}] for file in chat.queryFiles or []: - content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] + content += [{"type": "text", "text": f"## File: {file['name']}\n\n{file['content']}"}] user_message = AgentMessage(role="user", content=content) elif chat.by == "khoj" and chat.message: chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)] @@ -311,7 +314,10 @@ def construct_tool_chat_history( If no tool is provided inferred query for all tools used are added. """ chat_history: list = [] - base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: [] + + def base_extractor(iteration: ResearchIteration) -> List[str]: + return [] + extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = { ConversationCommand.SemanticSearchFiles: ( lambda iteration: [c["query"] for c in iteration.context] if iteration.context else [] @@ -498,7 +504,7 @@ async def save_to_conversation_log( logger.info( f""" -Saved Conversation Turn ({db_conversation.id if db_conversation else 'N/A'}): +Saved Conversation Turn ({db_conversation.id if db_conversation else "N/A"}): You ({user.username}): "{q}" Khoj: "{chat_response}" @@ -625,7 +631,7 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(chat.operatorContext): operator_context = chat.operatorContext - operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context]) + operator_content = "\n\n".join([f"## Task: {oc['query']}\n{oc['response']}\n" for oc in operator_context]) message_context += [ { "type": "text", @@ -744,7 +750,7 @@ def get_encoder( else: # as tiktoken doesn't recognize o1 model series yet encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) - except: + except Exception: encoder = tiktoken.encoding_for_model(default_tokenizer) if state.verbose > 2: logger.debug( @@ -846,9 +852,9 @@ def truncate_messages( total_tokens, _ = count_total_tokens(messages, encoder, system_message) if total_tokens > max_prompt_size: # At this point, a single message with a single content part of type dict should remain - assert ( - len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict) - ), "Expected a single message with a single content part remaining at this point in truncation" + assert len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict), ( + "Expected a single message with a single content part remaining at this point in truncation" + ) # Collate message content into single string to ease truncation part = messages[0].content[0] diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 0e7b6657..f9159474 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,8 +1,6 @@ import logging from typing import List -from urllib.parse import urlparse -import openai import requests import tqdm from sentence_transformers import CrossEncoder, SentenceTransformer diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 164e3991..fbe4d4f0 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -108,12 +108,12 @@ async def text_to_image( if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore - message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore + message = "Image generation blocked by OpenAI due to policy violation" # type: ignore yield image_url or image, status_code, message return else: logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed using OpenAI" # type: ignore + message = "Image generation failed using OpenAI" # type: ignore status_code = e.status_code # type: ignore yield image_url or image, status_code, message return @@ -199,7 +199,7 @@ def generate_image_with_stability( # Call Stability AI API to generate image response = requests.post( - f"https://api.stability.ai/v2beta/stable-image/generate/sd3", + "https://api.stability.ai/v2beta/stable-image/generate/sd3", headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, files={"none": ""}, data={ diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 2b63c40d..0aa12ca4 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -11,7 +11,7 @@ from khoj.processor.conversation.utils import ( OperatorRun, construct_chat_history_for_operator, ) -from khoj.processor.operator.operator_actions import * +from khoj.processor.operator.operator_actions import RequestUserAction from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent @@ -59,7 +59,7 @@ async def operate_environment( if not reasoning_model or not reasoning_model.vision_enabled: reasoning_model = await ConversationAdapters.aget_vision_enabled_config() if not reasoning_model: - raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") + raise ValueError("No vision enabled chat model found. Configure a vision chat model to operate environment.") # Create conversation history from conversation log chat_history = construct_chat_history_for_operator(conversation_log) diff --git a/src/khoj/processor/operator/grounding_agent.py b/src/khoj/processor/operator/grounding_agent.py index 16f2d510..fbd907b9 100644 --- a/src/khoj/processor/operator/grounding_agent.py +++ b/src/khoj/processor/operator/grounding_agent.py @@ -1,14 +1,27 @@ import json import logging from textwrap import dedent +from typing import List, Optional from openai import AzureOpenAI, OpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage from khoj.database.models import ChatModel from khoj.processor.conversation.utils import construct_structured_message -from khoj.processor.operator.operator_actions import * -from khoj.processor.operator.operator_agent_base import AgentActResult +from khoj.processor.operator.operator_actions import ( + BackAction, + ClickAction, + DoubleClickAction, + DragAction, + GotoAction, + KeypressAction, + OperatorAction, + Point, + ScreenshotAction, + ScrollAction, + TypeAction, + WaitAction, +) from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState from khoj.utils.helpers import get_chat_usage_metrics diff --git a/src/khoj/processor/operator/grounding_agent_uitars.py b/src/khoj/processor/operator/grounding_agent_uitars.py index 8209778d..7470607e 100644 --- a/src/khoj/processor/operator/grounding_agent_uitars.py +++ b/src/khoj/processor/operator/grounding_agent_uitars.py @@ -18,7 +18,22 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ChatCompletion from PIL import Image -from khoj.processor.operator.operator_actions import * +from khoj.processor.operator.operator_actions import ( + BackAction, + ClickAction, + DoubleClickAction, + DragAction, + GotoAction, + KeyDownAction, + KeypressAction, + KeyUpAction, + MoveAction, + OperatorAction, + RequestUserAction, + ScrollAction, + TypeAction, + WaitAction, +) from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState from khoj.utils.helpers import get_chat_usage_metrics @@ -122,11 +137,10 @@ class GroundingAgentUitars: ) temperature = self.temperature - top_k = self.top_k try_times = 3 while not parsed_responses: if try_times <= 0: - logger.warning(f"Reach max retry times to fetch response from client, as error flag.") + logger.warning("Reach max retry times to fetch response from client, as error flag.") return "client error\nFAIL", [] try: message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages]) @@ -163,7 +177,6 @@ class GroundingAgentUitars: prediction = None try_times -= 1 temperature = 1 - top_k = -1 if prediction is None: return "client error\nFAIL", [] @@ -264,9 +277,9 @@ class GroundingAgentUitars: raise ValueError(f"Unsupported environment type: {environment_type}") def _format_messages_for_api(self, instruction: str, current_state: EnvState): - assert len(self.observations) == len(self.actions) and len(self.actions) == len( - self.thoughts - ), "The number of observations and actions should be the same." + assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), ( + "The number of observations and actions should be the same." + ) self.history_images.append(base64.b64decode(current_state.screenshot)) self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None}) @@ -524,7 +537,7 @@ class GroundingAgentUitars: parsed_actions = [self.parse_action_string(action.replace("\n", "\\n").lstrip()) for action in all_action] actions: list[dict] = [] for action_instance, raw_str in zip(parsed_actions, all_action): - if action_instance == None: + if action_instance is None: print(f"Action can't parse: {raw_str}") raise ValueError(f"Action can't parse: {raw_str}") action_type = action_instance["function"] @@ -756,7 +769,7 @@ class GroundingAgentUitars: The pyautogui code string """ - pyautogui_code = f"import pyautogui\nimport time\n" + pyautogui_code = "import pyautogui\nimport time\n" actions = [] if isinstance(responses, dict): responses = [responses] @@ -774,7 +787,7 @@ class GroundingAgentUitars: if response_id == 0: pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n" else: - pyautogui_code += f"\ntime.sleep(1)\n" + pyautogui_code += "\ntime.sleep(1)\n" action_dict = response action_type = action_dict.get("action_type") @@ -846,17 +859,17 @@ class GroundingAgentUitars: if content: if input_swap: actions += TypeAction() - pyautogui_code += f"\nimport pyperclip" + pyautogui_code += "\nimport pyperclip" pyautogui_code += f"\npyperclip.copy('{stripped_content}')" - pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" - pyautogui_code += f"\ntime.sleep(0.5)\n" + pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')" + pyautogui_code += "\ntime.sleep(0.5)\n" if content.endswith("\n") or content.endswith("\\n"): - pyautogui_code += f"\npyautogui.press('enter')" + pyautogui_code += "\npyautogui.press('enter')" else: pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" - pyautogui_code += f"\ntime.sleep(0.5)\n" + pyautogui_code += "\ntime.sleep(0.5)\n" if content.endswith("\n") or content.endswith("\\n"): - pyautogui_code += f"\npyautogui.press('enter')" + pyautogui_code += "\npyautogui.press('enter')" elif action_type in ["drag", "select"]: # Parsing drag or select action based on start and end_boxes @@ -869,9 +882,7 @@ class GroundingAgentUitars: x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2] ex = round(float((x1 + x2) / 2) * image_width, 3) ey = round(float((y1 + y2) / 2) * image_height, 3) - pyautogui_code += ( - f"\npyautogui.moveTo({sx}, {sy})\n" f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n" - ) + pyautogui_code += f"\npyautogui.moveTo({sx}, {sy})\n\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n" elif action_type == "scroll": # Parsing scroll action @@ -888,11 +899,11 @@ class GroundingAgentUitars: y = None direction = action_inputs.get("direction", "") - if x == None: + if x is None: if "up" in direction.lower(): - pyautogui_code += f"\npyautogui.scroll(5)" + pyautogui_code += "\npyautogui.scroll(5)" elif "down" in direction.lower(): - pyautogui_code += f"\npyautogui.scroll(-5)" + pyautogui_code += "\npyautogui.scroll(-5)" else: if "up" in direction.lower(): pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" @@ -923,7 +934,7 @@ class GroundingAgentUitars: pyautogui_code += f"\npyautogui.moveTo({x}, {y})" elif action_type in ["finished"]: - pyautogui_code = f"DONE" + pyautogui_code = "DONE" else: pyautogui_code += f"\n# Unrecognized action type: {action_type}" diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 9d4db42f..4d93b956 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -11,7 +11,32 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo from khoj.database.models import ChatModel from khoj.processor.conversation.anthropic.utils import is_reasoning_model from khoj.processor.conversation.utils import AgentMessage -from khoj.processor.operator.operator_actions import * +from khoj.processor.operator.operator_actions import ( + BackAction, + ClickAction, + CursorPositionAction, + DoubleClickAction, + DragAction, + GotoAction, + HoldKeyAction, + KeypressAction, + MouseDownAction, + MouseUpAction, + MoveAction, + NoopAction, + OperatorAction, + Point, + ScreenshotAction, + ScrollAction, + TerminalAction, + TextEditorCreateAction, + TextEditorInsertAction, + TextEditorStrReplaceAction, + TextEditorViewAction, + TripleClickAction, + TypeAction, + WaitAction, +) from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -518,7 +543,7 @@ class AnthropicOperatorAgent(OperatorAgent): def model_default_headers(self) -> list[str]: """Get the default computer use headers for the given model.""" if self.vision_model.name.startswith("claude-3-7-sonnet"): - return [f"computer-use-2025-01-24", "token-efficient-tools-2025-02-19"] + return ["computer-use-2025-01-24", "token-efficient-tools-2025-02-19"] elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"): return ["computer-use-2025-01-24"] else: @@ -538,7 +563,7 @@ class AnthropicOperatorAgent(OperatorAgent): * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail. - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. * The current URL is {current_state.url}. @@ -563,7 +588,7 @@ class AnthropicOperatorAgent(OperatorAgent): - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. """ ).lstrip() diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 38e7744e..17260a0c 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Union +from typing import List, Optional from pydantic import BaseModel diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index ccd34d39..44bfa7ca 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -12,7 +12,7 @@ from khoj.processor.conversation.utils import ( ) from khoj.processor.operator.grounding_agent import GroundingAgent from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars -from khoj.processor.operator.operator_actions import * +from khoj.processor.operator.operator_actions import OperatorAction, WaitAction from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -181,7 +181,7 @@ class BinaryOperatorAgent(OperatorAgent): elif action.type == "key_down": rendered_parts += [f'**Action**: Press Key "{action.key}"'] elif action.type == "screenshot" and not current_state.screenshot: - rendered_parts += [f"**Error**: Failed to take screenshot"] + rendered_parts += ["**Error**: Failed to take screenshot"] elif action.type == "goto": rendered_parts += [f"**Action**: Open URL {action.url}"] else: @@ -317,7 +317,7 @@ class BinaryOperatorAgent(OperatorAgent): # Introduction * You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser. * You are given the user's query and screenshots of the browser's state transitions. - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. * The current URL is {env_state.url}. # Your Task @@ -362,7 +362,7 @@ class BinaryOperatorAgent(OperatorAgent): # Introduction * You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer. * You are given the user's query and screenshots of the computer's state transitions. - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. # Your Task * First look at the screenshots carefully to notice all pertinent information. diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index 9752e574..a8d550f1 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -1,6 +1,5 @@ import json import logging -import platform from copy import deepcopy from datetime import datetime from textwrap import dedent @@ -10,7 +9,23 @@ from openai.types.responses import Response, ResponseOutputItem from khoj.database.models import ChatModel from khoj.processor.conversation.utils import AgentMessage -from khoj.processor.operator.operator_actions import * +from khoj.processor.operator.operator_actions import ( + BackAction, + ClickAction, + DoubleClickAction, + DragAction, + GotoAction, + KeypressAction, + MoveAction, + NoopAction, + OperatorAction, + Point, + RequestUserAction, + ScreenshotAction, + ScrollAction, + TypeAction, + WaitAction, +) from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent from khoj.processor.operator.operator_environment_base import ( EnvironmentType, @@ -152,7 +167,7 @@ class OpenAIOperatorAgent(OperatorAgent): # Add screenshot data in openai message format action_result["output"] = { "type": "input_image", - "image_url": f'data:image/webp;base64,{result_content["image"]}', + "image_url": f"data:image/webp;base64,{result_content['image']}", "current_url": result_content["url"], } elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1: @@ -311,7 +326,7 @@ class OpenAIOperatorAgent(OperatorAgent): elif block.type == "function_call": if block.name == "goto": args = json.loads(block.arguments) - render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}'] + render_texts = [f"Open URL: {args.get('url', '[Missing URL]')}"] else: render_texts += [block.name] elif block.type == "computer_call": @@ -351,7 +366,7 @@ class OpenAIOperatorAgent(OperatorAgent): * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail. - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. * The current URL is {current_state.url}. @@ -374,7 +389,7 @@ class OpenAIOperatorAgent(OperatorAgent): - * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + * The current date is {datetime.today().strftime("%A, %B %-d, %Y")}. """ ).lstrip() diff --git a/src/khoj/processor/operator/operator_environment_browser.py b/src/khoj/processor/operator/operator_environment_browser.py index c28e53b7..8db72ade 100644 --- a/src/khoj/processor/operator/operator_environment_browser.py +++ b/src/khoj/processor/operator/operator_environment_browser.py @@ -247,7 +247,7 @@ class BrowserEnvironment(Environment): case "drag": if not isinstance(action, DragAction): - raise TypeError(f"Invalid action type for drag") + raise TypeError("Invalid action type for drag") path = action.path if not path: error = "Missing path for drag action" diff --git a/src/khoj/processor/operator/operator_environment_computer.py b/src/khoj/processor/operator/operator_environment_computer.py index b7dc9e8e..de946686 100644 --- a/src/khoj/processor/operator/operator_environment_computer.py +++ b/src/khoj/processor/operator/operator_environment_computer.py @@ -532,7 +532,7 @@ class ComputerEnvironment(Environment): else: return {"success": False, "output": process.stdout, "error": process.stderr} except asyncio.TimeoutError: - return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."} + return {"success": False, "output": "", "error": "Command timed out after 120 seconds."} except Exception as e: return {"success": False, "output": "", "error": str(e)} diff --git a/src/khoj/processor/speech/text_to_speech.py b/src/khoj/processor/speech/text_to_speech.py index 3aa6bf72..9e4a0d0e 100644 --- a/src/khoj/processor/speech/text_to_speech.py +++ b/src/khoj/processor/speech/text_to_speech.py @@ -1,4 +1,3 @@ -import json # Used for working with JSON data import os import requests # Used for making HTTP requests diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index ef66d489..130bff30 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -385,7 +385,7 @@ async def read_webpages( tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" - logger.info(f"Inferring web pages to read") + logger.info("Inferring web pages to read") urls = await infer_webpage_urls( query, max_webpages_to_read, diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 448ed699..2e3de666 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -93,7 +93,7 @@ async def run_code( # Run Code if send_status_func: - async for event in send_status_func(f"**Running code snippet**"): + async for event in send_status_func("**Running code snippet**"): yield {ChatEvent.STATUS: event} try: with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 44c2f2b7..f7345d94 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -7,7 +7,6 @@ from typing import List, Optional, Union import openai from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile -from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import has_required_scope, requires @@ -94,7 +93,7 @@ def update( logger.error(error_msg, exc_info=True) raise HTTPException(status_code=500, detail=error_msg) else: - logger.info(f"📪 Server indexed content updated via API") + logger.info("📪 Server indexed content updated via API") update_telemetry_state( request=request, diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index 3ac48d53..5c701e85 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -6,12 +6,11 @@ from typing import Dict, List, Optional from asgiref.sync import sync_to_async from fastapi import APIRouter, Request -from fastapi.requests import Request from fastapi.responses import Response from pydantic import BaseModel from starlette.authentication import has_required_scope, requires -from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters +from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, Conversation, KhojUser, PriceTier from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt from khoj.utils.helpers import ( diff --git a/src/khoj/routers/api_automation.py b/src/khoj/routers/api_automation.py index b965c9f9..8630a324 100644 --- a/src/khoj/routers/api_automation.py +++ b/src/khoj/routers/api_automation.py @@ -109,7 +109,7 @@ def post_automation( except Exception as e: logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True) return Response( - content=f"Unable to create automation. Ensure the automation doesn't already exist.", + content="Unable to create automation. Ensure the automation doesn't already exist.", media_type="text/plain", status_code=500, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 8670d35e..34956ba1 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -10,7 +10,6 @@ from functools import partial from typing import Any, Dict, List, Optional from urllib.parse import unquote -from asgiref.sync import sync_to_async from fastapi import ( APIRouter, Depends, @@ -32,10 +31,10 @@ from khoj.database.adapters import ( PublicConversationAdapters, aget_user_name, ) -from khoj.database.models import Agent, ChatMessageModel, KhojUser +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import is_local_api -from khoj.processor.conversation.prompts import help_message, no_entries_found +from khoj.processor.conversation.prompts import no_entries_found from khoj.processor.conversation.utils import ( OperatorRun, ResponseWithThought, @@ -65,11 +64,8 @@ from khoj.routers.helpers import ( acreate_title_from_history, agenerate_chat_response, aget_data_sources_and_output_format, - construct_automation_created_message, - create_automation, gather_raw_query_files, generate_mermaidjs_diagram, - generate_summary_from_files, get_conversation_command, get_message_from_queue, is_query_empty, @@ -89,13 +85,11 @@ from khoj.utils.helpers import ( convert_image_to_webp, get_country_code_from_timezone, get_country_name_from_timezone, - get_device, is_env_var_true, is_none_or_empty, is_operator_enabled, ) from khoj.utils.rawconfig import ( - ChatRequestBody, FileAttachment, FileFilterRequest, FilesFilterRequest, @@ -689,7 +683,6 @@ async def event_generator( region = body.region country = body.country or get_country_name_from_timezone(body.timezone) country_code = body.country_code or get_country_code_from_timezone(body.timezone) - timezone = body.timezone raw_images = body.images raw_query_files = body.files @@ -853,7 +846,8 @@ async def event_generator( if ( len(train_of_thought) > 0 and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value - and type(train_of_thought[-1]["data"]) == type(data) == str + and isinstance(train_of_thought[-1]["data"], str) + and isinstance(data, str) ): train_of_thought[-1]["data"] += data else: @@ -1075,11 +1069,11 @@ async def event_generator( # researched_results = await extract_relevant_info(q, researched_results, agent) if state.verbose > 1: - logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}') + logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}") # Gather Context ## Extract Document References - if not ConversationCommand.Research in conversation_commands: + if ConversationCommand.Research not in conversation_commands: try: async for result in search_documents( q, @@ -1218,7 +1212,7 @@ async def event_generator( else: code_results = result except ValueError as e: - program_execution_context.append(f"Failed to run code") + program_execution_context.append("Failed to run code") logger.warning( f"Failed to use code tool: {e}. Attempting to respond without code results", exc_info=True, @@ -1297,7 +1291,7 @@ async def event_generator( inferred_queries.append(improved_image_prompt) if generated_image is None or status_code != 200: program_execution_context.append(f"Failed to generate image with {improved_image_prompt}") - async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"): + async for result in send_event(ChatEvent.STATUS, "Failed to generate image"): yield result else: generated_images.append(generated_image) @@ -1315,7 +1309,7 @@ async def event_generator( yield result if ConversationCommand.Diagram in conversation_commands: - async for result in send_event(ChatEvent.STATUS, f"Creating diagram"): + async for result in send_event(ChatEvent.STATUS, "Creating diagram"): yield result inferred_queries = [] @@ -1372,7 +1366,7 @@ async def event_generator( return ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + async for result in send_event(ChatEvent.STATUS, "**Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index 1c27f5f3..bb99d20b 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -3,7 +3,6 @@ import logging from typing import Dict, Optional, Union from fastapi import APIRouter, Request -from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import has_required_scope, requires diff --git a/src/khoj/routers/api_subscription.py b/src/khoj/routers/api_subscription.py index 11798958..b24e3a9f 100644 --- a/src/khoj/routers/api_subscription.py +++ b/src/khoj/routers/api_subscription.py @@ -117,7 +117,7 @@ async def subscribe(request: Request): ) logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}") - logger.info(f'Stripe subscription {event["type"]} for {customer_email}') + logger.info(f"Stripe subscription {event['type']} for {customer_email}") return {"success": success} diff --git a/src/khoj/routers/email.py b/src/khoj/routers/email.py index ce9e0ae6..2395e68d 100644 --- a/src/khoj/routers/email.py +++ b/src/khoj/routers/email.py @@ -44,7 +44,7 @@ async def send_magic_link_email(email, unique_id, host): { "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "to": email, - "subject": f"Your login code to Khoj", + "subject": "Your login code to Khoj", "html": html_content, } ) @@ -98,11 +98,11 @@ async def send_query_feedback(uquery, kquery, sentiment, user_email): user_email=user_email if not is_none_or_empty(user_email) else "N/A", ) # send feedback to fixed account - r = resend.Emails.send( + resend.Emails.send( { "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "to": "team@khoj.dev", - "subject": f"User Feedback", + "subject": "User Feedback", "html": html_content, } ) @@ -127,7 +127,7 @@ def send_task_email(name, email, query, result, subject, is_image=False): r = resend.Emails.send( { - "sender": f'Khoj <{os.environ.get("RESEND_EMAIL", "khoj@khoj.dev")}>', + "sender": f"Khoj <{os.environ.get('RESEND_EMAIL', 'khoj@khoj.dev')}>", "to": email, "subject": f"✨ {subject}", "html": html_content, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8dcda86b..8fce03aa 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,6 +1,5 @@ import asyncio import base64 -import concurrent.futures import fnmatch import hashlib import json @@ -47,14 +46,12 @@ from khoj.database.adapters import ( EntryAdapters, FileObjectAdapters, aget_user_by_email, - ais_user_subscribed, create_khoj_token, get_default_search_model, get_khoj_tokens, get_user_name, get_user_notion_config, get_user_subscription_state, - is_user_subscribed, run_with_process_lock, ) from khoj.database.models import ( @@ -160,7 +157,7 @@ def validate_chat_model(user: KhojUser): async def is_ready_to_chat(user: KhojUser): user_chat_model = await ConversationAdapters.aget_user_chat_model(user) - if user_chat_model == None: + if user_chat_model is None: user_chat_model = await ConversationAdapters.aget_default_chat_model(user) if ( @@ -581,7 +578,7 @@ async def generate_online_subqueries( ) return {q} return response - except Exception as e: + except Exception: logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}") return {q} @@ -1172,8 +1169,8 @@ async def search_documents( agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) if ( - not ConversationCommand.Notes in conversation_commands - and not ConversationCommand.Default in conversation_commands + ConversationCommand.Notes not in conversation_commands + and ConversationCommand.Default not in conversation_commands and not agent_has_entries ): yield compiled_references, inferred_queries, q @@ -1325,8 +1322,8 @@ async def extract_questions( logger.error(f"Invalid response for constructing subqueries: {response}") return [query] return queries - except: - logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.") + except Exception: + logger.warning("LLM returned invalid JSON. Falling back to using user message as search query.") return [query] @@ -1351,7 +1348,7 @@ async def execute_search( return results if q is None or q == "": - logger.warning(f"No query param (q) passed in API call to initiate search") + logger.warning("No query param (q) passed in API call to initiate search") return results # initialize variables @@ -1364,7 +1361,7 @@ async def execute_search( if user: query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}" if query_cache_key in state.query_cache[user.uuid]: - logger.debug(f"Return response from query cache") + logger.debug("Return response from query cache") return state.query_cache[user.uuid][query_cache_key] # Encode query with filter terms removed @@ -1875,8 +1872,8 @@ class ApiUserRateLimiter: user: KhojUser = websocket.scope["user"].object subscribed = has_required_scope(websocket, ["premium"]) - current_window = "today" if self.window == 60 * 60 * 24 else f"now" - next_window = "tomorrow" if self.window == 60 * 60 * 24 else f"in a bit" + current_window = "today" if self.window == 60 * 60 * 24 else "now" + next_window = "tomorrow" if self.window == 60 * 60 * 24 else "in a bit" common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}." # Remove requests outside of the time window @@ -2219,7 +2216,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str, us should_notify_result = response["decision"] == "Yes" reason = response.get("reason", "unknown") logger.info( - f'Decided to {"not " if not should_notify_result else ""}notify user of automation response because of reason: {reason}.' + f"Decided to {'not ' if not should_notify_result else ''}notify user of automation response because of reason: {reason}." ) return should_notify_result except Exception as e: @@ -2313,7 +2310,7 @@ def scheduled_chat( response_map = raw_response.json() ai_response = response_map.get("response") or response_map.get("image") is_image = False - if type(ai_response) == dict: + if isinstance(ai_response, dict): is_image = ai_response.get("image") is not None else: ai_response = raw_response.text @@ -2460,12 +2457,12 @@ async def aschedule_automation( def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str): # Display next run time in user timezone instead of UTC - schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}' + schedule = f"{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime('%Z')}" next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z") # Remove /automated_task prefix from inferred_query unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run) # Create the automation response - automation_icon_url = f"/static/assets/icons/automation.svg" + automation_icon_url = "/static/assets/icons/automation.svg" return f""" ### ![]({automation_icon_url}) Created Automation - Subject: **{subject}** @@ -2713,13 +2710,13 @@ def configure_content( t: Optional[state.SearchType] = state.SearchType.All, ) -> bool: success = True - if t == None: + if t is None: t = state.SearchType.All if t is not None and t in [type.value for type in state.SearchType]: t = state.SearchType(t) - if t is not None and not t.value in [type.value for type in state.SearchType]: + if t is not None and t.value not in [type.value for type in state.SearchType]: logger.warning(f"🚨 Invalid search type: {t}") return False @@ -2988,7 +2985,7 @@ async def grep_files( query += f" {' and '.join(context_info)}" if line_count > max_results: if lines_before or lines_after: - query += f" for" + query += " for" query += f" first {max_results} results" return query diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d54a147f..d1ebe298 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -15,7 +15,6 @@ from khoj.processor.conversation.utils import ( ResearchIteration, ToolCall, construct_iteration_history, - construct_structured_message, construct_tool_chat_history, load_complex_json, ) @@ -24,7 +23,6 @@ from khoj.processor.tools.online_search import read_webpages_content, search_onl from khoj.processor.tools.run_code import run_code from khoj.routers.helpers import ( ChatEvent, - generate_summary_from_files, get_message_from_queue, grep_files, list_files, @@ -184,7 +182,7 @@ async def apick_next_tool( # TODO: Handle multiple tool calls. response_text = response.text parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] - except Exception as e: + except Exception: # Otherwise assume the model has decided to end the research run and respond to the user. parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) @@ -199,7 +197,7 @@ async def apick_next_tool( if i.warning is None and isinstance(i.query, ToolCall) } if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: - warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different." + warning = "Repeated tool, query combination detected. Skipping iteration. Try something different." # Only send client status updates if we'll execute this iteration and model has thoughts to share. elif send_status_func and not is_none_or_empty(response.thought): async for event in send_status_func(response.thought): diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py index ee6b10a5..304b4d1d 100644 --- a/src/khoj/search_filter/base_filter.py +++ b/src/khoj/search_filter/base_filter.py @@ -4,12 +4,10 @@ from typing import List class BaseFilter(ABC): @abstractmethod - def get_filter_terms(self, query: str) -> List[str]: - ... + def get_filter_terms(self, query: str) -> List[str]: ... def can_filter(self, raw_query: str) -> bool: return len(self.get_filter_terms(raw_query)) > 0 @abstractmethod - def defilter(self, query: str) -> str: - ... + def defilter(self, query: str) -> str: ... diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index b2b3453b..1698744c 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -9,9 +9,8 @@ from asgiref.sync import sync_to_async from sentence_transformers import util from khoj.database.adapters import EntryAdapters, get_default_search_model -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.database.models import Entry as DbEntry -from khoj.database.models import KhojUser from khoj.processor.content.text_to_entries import TextToEntries from khoj.utils import state from khoj.utils.helpers import get_absolute_path, timer diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 523ec007..a9f0c7f6 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -77,7 +77,7 @@ class AsyncIteratorWrapper: def is_none_or_empty(item): - return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == "" + return item is None or (hasattr(item, "__iter__") and len(item) == 0) or item == "" def to_snake_case_from_dash(item: str): @@ -97,7 +97,7 @@ def get_from_dict(dictionary, *args): Returns: dictionary[args[0]][args[1]]... or None if any keys missing""" current = dictionary for arg in args: - if not hasattr(current, "__iter__") or not arg in current: + if not hasattr(current, "__iter__") or arg not in current: return None current = current[arg] return current @@ -751,7 +751,7 @@ def is_valid_url(url: str) -> bool: try: result = urlparse(url.strip()) return all([result.scheme, result.netloc]) - except: + except Exception: return False @@ -759,7 +759,7 @@ def is_internet_connected(): try: response = requests.head("https://www.google.com") return response.status_code == 200 - except: + except Exception: return False diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 8023b3ed..2dcc6c53 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -60,9 +60,7 @@ def initialization(interactive: bool = True): ] default_chat_models = known_available_models + other_available_models except Exception as e: - logger.warning( - f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}" - ) + logger.warning(f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}") # Set up OpenAI's online chat models openai_configured, openai_provider = _setup_chat_model_provider( diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index f848d4e4..6b80c705 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -8,12 +8,10 @@ from tqdm import trange class BaseEncoder(ABC): @abstractmethod - def __init__(self, model_name: str, device: torch.device = None, **kwargs): - ... + def __init__(self, model_name: str, device: torch.device = None, **kwargs): ... @abstractmethod - def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: - ... + def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: ... class OpenAI(BaseEncoder): diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 5377577b..eeca58bb 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -1,8 +1,7 @@ # System Packages import json import uuid -from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional from pydantic import BaseModel diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 3b65a85b..3958173d 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -2,7 +2,7 @@ import os import threading from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List from apscheduler.schedulers.background import BackgroundScheduler from openai import OpenAI diff --git a/src/telemetry/telemetry.py b/src/telemetry/telemetry.py index 576befd0..02bda3dc 100644 --- a/src/telemetry/telemetry.py +++ b/src/telemetry/telemetry.py @@ -30,7 +30,7 @@ def v1_telemetry(telemetry_data: List[Dict[str, str]]): try: for row in telemetry_data: posthog.capture(row["server_id"], "api_request", row) - except Exception as e: + except Exception: raise HTTPException( status_code=500, detail="Could not POST equest to new khoj telemetry server. Contact developer to get this fixed.", diff --git a/tests/conftest.py b/tests/conftest.py index dd448bd1..1b33e94d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -326,7 +326,7 @@ File statistics: - Code examples: Yes - Purpose: Stress testing atomic agent updates -{'Additional padding content. ' * 20} +{"Additional padding content. " * 20} End of file {i}. """ diff --git a/tests/evals/eval.py b/tests/evals/eval.py index b93d2667..85188d2e 100644 --- a/tests/evals/eval.py +++ b/tests/evals/eval.py @@ -462,7 +462,7 @@ def evaluate_response_with_gemini( Ground Truth: {ground_truth} Provide your evaluation in the following json format: - {"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"} + {"explanation:[How you made the decision?)", "decision:(TRUE if response contains key information, FALSE otherwise)"} """ gemini_api_url = ( f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}" @@ -557,7 +557,7 @@ def process_batch(batch, batch_start, results, dataset_length, response_evaluato --------- Decision: {colored_decision} Accuracy: {running_accuracy:.2%} -Progress: {running_total_count.get()/dataset_length:.2%} +Progress: {running_total_count.get() / dataset_length:.2%} Index: {current_index} Question: {prompt} Expected Answer: {answer} diff --git a/tests/test_agents.py b/tests/test_agents.py index 1d3b96ec..cad2e888 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -20,7 +20,7 @@ def test_create_default_agent(default_user: KhojUser): assert agent.input_tools == [] assert agent.output_modes == [] assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC - assert agent.managed_by_admin == True + assert agent.managed_by_admin @pytest.mark.anyio @@ -178,7 +178,7 @@ async def test_multiple_agents_with_knowledge_base_and_users( default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") - new_agent = await AgentAdapters.aupdate_agent( + await AgentAdapters.aupdate_agent( default_user2, "Test Agent", "Test Personality", @@ -290,17 +290,17 @@ async def test_large_knowledge_base_atomic_update( assert len(final_entries) > initial_entries_count, "Should have more entries after update" # With 180 files, we should have many entries (each file creates multiple entries) - assert ( - len(final_entries) >= expected_file_count - ), f"Expected at least {expected_file_count} entries, got {len(final_entries)}" + assert len(final_entries) >= expected_file_count, ( + f"Expected at least {expected_file_count} entries, got {len(final_entries)}" + ) # Verify no partial state - all entries should correspond to the final file set entry_file_paths = {entry.file_path for entry in final_entries} # All file objects should have corresponding entries - assert file_paths_in_db.issubset( - entry_file_paths - ), "All file objects should have corresponding entries - atomic update verification" + assert file_paths_in_db.issubset(entry_file_paths), ( + "All file objects should have corresponding entries - atomic update verification" + ) # Additional stress test: verify referential integrity # Count entries per file to ensure no partial file processing @@ -333,7 +333,7 @@ async def test_concurrent_agent_updates_atomicity( test_files = available_files # Use all available files for the stress test # Create initial agent - agent = await AgentAdapters.aupdate_agent( + await AgentAdapters.aupdate_agent( default_user2, "Concurrent Test Agent", "Test concurrent updates", @@ -391,14 +391,14 @@ async def test_concurrent_agent_updates_atomicity( file_object_paths = {fo.file_name for fo in final_file_objects} # All entries should have corresponding file objects - assert entry_file_paths.issubset( - file_object_paths - ), "All entries should have corresponding file objects - indicates atomic update worked" + assert entry_file_paths.issubset(file_object_paths), ( + "All entries should have corresponding file objects - indicates atomic update worked" + ) except Exception as e: # If we get database integrity errors, that's actually expected behavior # with proper atomic transactions - they should fail cleanly rather than # allowing partial updates - assert ( - "database" in str(e).lower() or "integrity" in str(e).lower() - ), f"Expected database/integrity error with concurrent updates, got: {e}" + assert "database" in str(e).lower() or "integrity" in str(e).lower(), ( + f"Expected database/integrity error with concurrent updates, got: {e}" + ) diff --git a/tests/test_client.py b/tests/test_client.py index 46732a86..a9c8f973 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,6 @@ from urllib.parse import quote import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from PIL import Image from khoj.configure import configure_routes, configure_search_types from khoj.database.adapters import EntryAdapters @@ -101,7 +100,7 @@ def test_update_with_invalid_content_type(client): headers = {"Authorization": "Bearer kk-secret"} # Act - response = client.get(f"/api/update?t=invalid_content_type", headers=headers) + response = client.get("/api/update?t=invalid_content_type", headers=headers) # Assert assert response.status_code == 422 @@ -114,7 +113,7 @@ def test_regenerate_with_invalid_content_type(client): headers = {"Authorization": "Bearer kk-secret"} # Act - response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers) + response = client.get("/api/update?force=true&t=invalid_content_type", headers=headers) # Assert assert response.status_code == 422 @@ -238,13 +237,13 @@ def test_regenerate_with_valid_content_type(client): def test_regenerate_with_github_fails_without_pat(client): # Act headers = {"Authorization": "Bearer kk-secret"} - response = client.get(f"/api/update?force=true&t=github", headers=headers) + response = client.get("/api/update?force=true&t=github", headers=headers) # Arrange files = get_sample_files_data() # Act - response = client.patch(f"/api/content?t=github", files=files, headers=headers) + response = client.patch("/api/content?t=github", files=files, headers=headers) # Assert assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github" @@ -270,7 +269,7 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser): text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) # Act - response = client.get(f"/api/content/types", headers=headers) + response = client.get("/api/content/types", headers=headers) # Assert assert response.status_code == 200 @@ -286,7 +285,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): client = TestClient(fastapi_app) # Act - response = client.get(f"/api/content/types") + response = client.get("/api/content/types") # Assert assert response.status_code == 200 @@ -454,8 +453,8 @@ def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojAp headers = {"Authorization": f"Bearer {api_user2.token}"} # Act - auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query}, headers=headers) - no_auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query}) + auth_response = chat_client_with_auth.post("/api/chat", json={"q": query}, headers=headers) + no_auth_response = chat_client_with_auth.post("/api/chat", json={"q": query}) # Assert assert auth_response.status_code == 200 diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py index e18ae7ba..f50b7fdb 100644 --- a/tests/test_conversation_utils.py +++ b/tests/test_conversation_utils.py @@ -77,12 +77,12 @@ class TestTruncateMessage: # Assert # The original object has been modified. Verify certain properties - assert ( - len(chat_history) == 1 - ), "Only most recent message should be present as it itself is larger than context size" - assert len(truncated_chat_history[0].content) < len( - copy_big_chat_message.content - ), "message content list should be modified" + assert len(chat_history) == 1, ( + "Only most recent message should be present as it itself is larger than context size" + ) + assert len(truncated_chat_history[0].content) < len(copy_big_chat_message.content), ( + "message content list should be modified" + ) assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size" @@ -101,9 +101,9 @@ class TestTruncateMessage: # Assert # The original object has been modified. Verify certain properties - assert ( - len(chat_history) == 1 - ), "Only most recent message should be present as it itself is larger than context size" + assert len(chat_history) == 1, ( + "Only most recent message should be present as it itself is larger than context size" + ) assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" @@ -150,9 +150,9 @@ class TestTruncateMessage: # The original object has been modified. Verify certain properties assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" - assert ( - len(chat_messages) == 1 - ), "Only most recent message should be present as it itself is larger than context size" + assert len(chat_messages) == 1, ( + "Only most recent message should be present as it itself is larger than context size" + ) assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved" @@ -172,9 +172,9 @@ class TestTruncateMessage: # The original object has been modified. Verify certain properties assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size" assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size" - assert ( - len(chat_messages) == 1 - ), "Only most recent message should be present as it itself is larger than context size" + assert len(chat_messages) == 1, ( + "Only most recent message should be present as it itself is larger than context size" + ) assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified" diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 1e11348f..0cd7a306 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -162,15 +162,15 @@ def test_date_extraction(): assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected d.m.Y structured date to be extracted" extracted_dates = DateFilter().extract_dates("CLOCK: [1984-04-01 Sun 09:50]--[1984-04-01 Sun 10:10] => 24:20") - assert extracted_dates == [ - datetime(1984, 4, 1, 0, 0, 0) - ], "Expected single deduplicated date extracted from logbook entry" + assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], ( + "Expected single deduplicated date extracted from logbook entry" + ) extracted_dates = DateFilter().extract_dates("CLOCK: [1984/03/31 mer 09:50]--[1984/04/01 mer 10:10] => 24:20") expected_dates = [datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 3, 31, 0, 0, 0)] - assert all( - [dt in extracted_dates for dt in expected_dates] - ), "Expected multiple different dates extracted from logbook entry" + assert all([dt in extracted_dates for dt in expected_dates]), ( + "Expected multiple different dates extracted from logbook entry" + ) def test_natual_date_extraction(): @@ -187,9 +187,9 @@ def test_natual_date_extraction(): assert datetime(1984, 4, 4, 0, 0, 0) in extracted_dates, "Expected natural date to be extracted" extracted_dates = DateFilter().extract_dates("head 11th april 1984 tail") - assert ( - datetime(1984, 4, 11, 0, 0, 0) in extracted_dates - ), "Expected natural date with lowercase month to be extracted" + assert datetime(1984, 4, 11, 0, 0, 0) in extracted_dates, ( + "Expected natural date with lowercase month to be extracted" + ) extracted_dates = DateFilter().extract_dates("head 23rd april 84 tail") assert datetime(1984, 4, 23, 0, 0, 0) in extracted_dates, "Expected natural date with 2-digit year to be extracted" @@ -201,16 +201,16 @@ def test_natual_date_extraction(): assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected partial natural date to be extracted" extracted_dates = DateFilter().extract_dates("head Apr 1984 tail") - assert extracted_dates == [ - datetime(1984, 4, 1, 0, 0, 0) - ], "Expected partial natural date with short month to be extracted" + assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], ( + "Expected partial natural date with short month to be extracted" + ) extracted_dates = DateFilter().extract_dates("head apr 1984 tail") - assert extracted_dates == [ - datetime(1984, 4, 1, 0, 0, 0) - ], "Expected partial natural date with lowercase month to be extracted" + assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], ( + "Expected partial natural date with lowercase month to be extracted" + ) extracted_dates = DateFilter().extract_dates("head apr 84 tail") - assert extracted_dates == [ - datetime(1984, 4, 1, 0, 0, 0) - ], "Expected partial natural date with 2-digit year to be extracted" + assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], ( + "Expected partial natural date with 2-digit year to be extracted" + ) diff --git a/tests/test_image_to_entries.py b/tests/test_image_to_entries.py index 77254c0a..24bf8192 100644 --- a/tests/test_image_to_entries.py +++ b/tests/test_image_to_entries.py @@ -1,5 +1,3 @@ -import os - from khoj.processor.content.images.image_to_entries import ImageToEntries diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py index b8ab37a7..3b60e6da 100644 --- a/tests/test_markdown_to_entries.py +++ b/tests/test_markdown_to_entries.py @@ -8,7 +8,7 @@ from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntrie def test_extract_markdown_with_no_headings(tmp_path): "Convert markdown file with no heading to entry format." # Arrange - entry = f""" + entry = """ - Bullet point 1 - Bullet point 2 """ @@ -35,7 +35,7 @@ def test_extract_markdown_with_no_headings(tmp_path): def test_extract_single_markdown_entry(tmp_path): "Convert markdown from single file to entry format." # Arrange - entry = f"""### Heading + entry = """### Heading \t\r Body Line 1 """ @@ -55,7 +55,7 @@ def test_extract_single_markdown_entry(tmp_path): def test_extract_multiple_markdown_entries(tmp_path): "Convert multiple markdown from single file to entry format." # Arrange - entry = f""" + entry = """ ### Heading 1 \t\r Heading 1 Body Line 1 @@ -81,7 +81,7 @@ def test_extract_multiple_markdown_entries(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path): "Extract markdown entries with different level headings." # Arrange - entry = f""" + entry = """ # Heading 1 ## Sub-Heading 1.1 # Heading 2 @@ -104,7 +104,7 @@ def test_extract_entries_with_different_level_headings(tmp_path): def test_extract_entries_with_non_incremental_heading_levels(tmp_path): "Extract markdown entries when deeper child level before shallower child level." # Arrange - entry = f""" + entry = """ # Heading 1 #### Sub-Heading 1.1 ## Sub-Heading 1.2 @@ -129,7 +129,7 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path): def test_extract_entries_with_text_before_headings(tmp_path): "Extract markdown entries with some text before any headings." # Arrange - entry = f""" + entry = """ Text before headings # Heading 1 body line 1 @@ -149,15 +149,15 @@ body line 2 assert len(entries[1]) == 3 assert entries[1][0].raw == "\nText before headings" assert entries[1][1].raw == "# Heading 1\nbody line 1" - assert ( - entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n" - ), "Ensure raw entry includes heading ancestory" + assert entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", ( + "Ensure raw entry includes heading ancestory" + ) def test_parse_markdown_file_into_single_entry_if_small(tmp_path): "Parse markdown file into single entry if it fits within the token limits." # Arrange - entry = f""" + entry = """ # Heading 1 body line 1 ## Subheading 1.1 @@ -180,7 +180,7 @@ body line 1.1 def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path): "Parse markdown entry with child headings as single entry if it fits within the tokens limits." # Arrange - entry = f""" + entry = """ # Heading 1 body line 1 ## Subheading 1.1 @@ -201,13 +201,13 @@ longer body line 2.1 # Assert assert len(entries) == 2 assert len(entries[1]) == 3 - assert ( - entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1" - ), "First entry includes children headings" + assert entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1", ( + "First entry includes children headings" + ) assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings" - assert ( - entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n" - ), "Third entry is second entries child heading" + assert entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n", ( + "Third entry is second entries child heading" + ) def test_line_number_tracking_in_recursive_split(): @@ -252,14 +252,16 @@ def test_line_number_tracking_in_recursive_split(): assert entry.uri is not None, f"Entry '{entry}' has a None URI." assert match is not None, f"URI format is incorrect: {entry.uri}" - assert ( - filepath_from_uri == markdown_file_path - ), f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'" + assert filepath_from_uri == markdown_file_path, ( + f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'" + ) # Ensure the first non-heading line in the compiled entry matches the line in the file assert ( cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip() - ), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" + ), ( + f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" + ) # Helper Functions diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 4afd0828..fea5d461 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -343,12 +343,12 @@ Expenses:Food:Dining 10.00 USD""", "file": "Ledger.org", }, { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" + "compiled": """2020-04-01 "SuperMercado" "Bananas" Expenses:Food:Groceries 10.00 USD""", "file": "Ledger.org", }, { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" + "compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner" Expenses:Food:Dining 10.00 USD""", "file": "Ledger.org", }, @@ -389,12 +389,12 @@ Expenses:Food:Dining 10.00 USD""", "file": "Ledger.md", }, { - "compiled": f"""2020-04-01 "SuperMercado" "Bananas" + "compiled": """2020-04-01 "SuperMercado" "Bananas" Expenses:Food:Groceries 10.00 USD""", "file": "Ledger.md", }, { - "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner" + "compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner" Expenses:Food:Dining 10.00 USD""", "file": "Ledger.md", }, @@ -452,17 +452,17 @@ async def test_ask_for_clarification_if_not_enough_context_in_question(): # Arrange context = [ { - "compiled": f"""# Ramya + "compiled": """# Ramya My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""", "file": "Family.md", }, { - "compiled": f"""# Fang + "compiled": """# Fang My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""", "file": "Family.md", }, { - "compiled": f"""# Aiyla + "compiled": """# Aiyla My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""", "file": "Family.md", }, @@ -497,9 +497,9 @@ async def test_agent_prompt_should_be_used(openai_agent): "Chat actor should ask be tuned to think like an accountant based on the agent definition" # Arrange context = [ - {"compiled": f"""I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"}, - {"compiled": f"""I went to the store and bought some apples for 1.30""", "file": "Ledger.md"}, - {"compiled": f"""I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"}, + {"compiled": """I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"}, + {"compiled": """I went to the store and bought some apples for 1.30""", "file": "Ledger.md"}, + {"compiled": """I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"}, ] expected_responses = ["9.50", "9.5"] @@ -539,13 +539,13 @@ async def test_websearch_with_operators(chat_client, default_user2): responses = await generate_online_subqueries(user_query, [], None, default_user2) # Assert - assert any( - ["reddit.com/r/worldnews" in response for response in responses] - ), "Expected a search query to include site:reddit.com but got: " + str(responses) + assert any(["reddit.com/r/worldnews" in response for response in responses]), ( + "Expected a search query to include site:reddit.com but got: " + str(responses) + ) - assert any( - ["site:reddit.com" in response for response in responses] - ), "Expected a search query to include site:reddit.com but got: " + str(responses) + assert any(["site:reddit.com" in response for response in responses]), ( + "Expected a search query to include site:reddit.com but got: " + str(responses) + ) # ---------------------------------------------------------------------------------------------------- @@ -559,9 +559,9 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u responses = await generate_online_subqueries(user_query, [], None, default_user2) # Assert - assert any( - ["site:khoj.dev" in response for response in responses] - ), "Expected search query to include site:khoj.dev but got: " + str(responses) + assert any(["site:khoj.dev" in response for response in responses]), ( + "Expected search query to include site:khoj.dev but got: " + str(responses) + ) # ---------------------------------------------------------------------------------------------------- @@ -693,9 +693,9 @@ def test_infer_task_scheduling_request( for expected_q in expected_qs: assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}" for unexpected_q in unexpected_qs: - assert ( - unexpected_q not in inferred_query - ), f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'" + assert unexpected_q not in inferred_query, ( + f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'" + ) # ---------------------------------------------------------------------------------------------------- diff --git a/tests/test_online_chat_director.py b/tests/test_online_chat_director.py index b5d3331e..b07f345e 100644 --- a/tests/test_online_chat_director.py +++ b/tests/test_online_chat_director.py @@ -33,7 +33,7 @@ def create_conversation(message_list, user, agent=None): @pytest.mark.django_db(transaction=True) def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act - response = chat_client.post(f"/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"}) + response = chat_client.post("/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"}) response_message = response.json()["response"] # Assert @@ -50,7 +50,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client): def test_chat_with_online_content(chat_client): # Act q = "/online give me the link to paul graham's essay how to do great work" - response = chat_client.post(f"/api/chat?", json={"q": q}) + response = chat_client.post("/api/chat?", json={"q": q}) response_message = response.json()["response"] # Assert @@ -59,9 +59,9 @@ def test_chat_with_online_content(chat_client): "paulgraham.com/hwh.html", ] assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected links: {expected_responses}. Actual response: {response_message}" + assert any([expected_response in response_message for expected_response in expected_responses]), ( + f"Expected links: {expected_responses}. Actual response: {response_message}" + ) # ---------------------------------------------------------------------------------------------------- @@ -70,15 +70,15 @@ def test_chat_with_online_content(chat_client): def test_chat_with_online_webpage_content(chat_client): # Act q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" - response = chat_client.post(f"/api/chat", json={"q": q}) + response = chat_client.post("/api/chat", json={"q": q}) response_message = response.json()["response"] # Assert expected_responses = ["185", "1871", "horse"] assert response.status_code == 200 - assert any( - [expected_response in response_message for expected_response in expected_responses] - ), f"Expected links: {expected_responses}. Actual response: {response_message}" + assert any([expected_response in response_message for expected_response in expected_responses]), ( + f"Expected links: {expected_responses}. Actual response: {response_message}" + ) # ---------------------------------------------------------------------------------------------------- @@ -93,7 +93,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser): create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": "What is my name?"}) + response = chat_client.post("/api/chat", json={"q": "What is my name?"}) response_message = response.content.decode("utf-8") # Assert @@ -120,7 +120,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": "Where was Xi Li born?"}) + response = chat_client.post("/api/chat", json={"q": "Where was Xi Li born?"}) response_message = response.json()["response"] # Assert @@ -144,7 +144,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n create_conversation(message_list, default_user2) # Act - response = chat_client_no_background.post(f"/api/chat", json={"q": "Where was I born?"}) + response = chat_client_no_background.post("/api/chat", json={"q": "Where was I born?"}) response_message = response.json()["response"] # Assert @@ -167,7 +167,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"}) + response = chat_client.post("/api/chat", json={"q": "Where was I born?"}) response_message = response.json()["response"] # Assert @@ -192,7 +192,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"}) + response = chat_client.post("/api/chat", json={"q": "Where was I born?"}) response_message = response.json()["response"] # Assert @@ -222,7 +222,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser): create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": query, "stream": True}) + response = chat_client.post("/api/chat", json={"q": query, "stream": True}) response_message = response.content.decode("utf-8") # Assert @@ -240,7 +240,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_ create_conversation(message_list, default_user2) # Act - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"] # Assert @@ -258,7 +258,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default create_conversation(message_list, default_user2) # Act - response = chat_client_no_background.post(f"/api/chat", json={"q": query}) + response = chat_client_no_background.post("/api/chat", json={"q": query}) response_message = response.json()["response"] # Assert @@ -291,7 +291,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser): json={"filename": summarization_file, "conversation_id": str(conversation.id)}, ) query = "/summarize" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert assert response_message != "" @@ -322,7 +322,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser): json={"filename": summarization_file, "conversation_id": str(conversation.id)}, ) query = "/summarize tell me about Xiu" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert assert response_message != "" @@ -349,7 +349,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser): ) query = "/summarize" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert @@ -365,7 +365,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser): # Act query = "/summarize" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert @@ -400,11 +400,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser): # Act query = "/summarize" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation2.id)}) response_message_conv2 = response.json()["response"] # now make sure that the file filter is still in conversation 1 - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation1.id)}) response_message_conv1 = response.json()["response"] # Assert @@ -430,7 +430,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser): json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)}, ) query = urllib.parse.quote("/summarize") - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert assert response_message == "No files selected for summarization. Please add files using the section on the left." @@ -462,7 +462,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi # Act query = "/summarize" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert @@ -477,7 +477,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi def test_answer_requires_current_date_awareness(chat_client): "Chat actor should be able to answer questions relative to current date using provided notes" # Act - response = chat_client.post(f"/api/chat", json={"q": "Where did I have lunch today?", "stream": True}) + response = chat_client.post("/api/chat", json={"q": "Where did I have lunch today?", "stream": True}) response_message = response.content.decode("utf-8") # Assert @@ -496,7 +496,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien "Chat director should be able to answer questions that require date aware aggregation across multiple notes" # Act query = "How much did I spend on dining this year?" - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"] # Assert @@ -518,7 +518,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c # Act query = "Write a haiku about unit testing. Do not say anything else." - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"] # Assert @@ -536,7 +536,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): # Act query = "What is the name of Namitas older son?" - response = chat_client_no_background.post(f"/api/chat", json={"q": query}) + response = chat_client_no_background.post("/api/chat", json={"q": query}) response_message = response.json()["response"].lower() # Assert @@ -571,7 +571,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user # Act query = "What is my name?" - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"] # Assert @@ -604,9 +604,7 @@ def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: K # Act query = "/general What is my favorite color?" - response = chat_client.post( - f"/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True} - ) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True}) response_message = response.content.decode("utf-8") # Assert @@ -639,7 +637,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent( # Act query = "/general What did I buy for breakfast?" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) + response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert that agent only responds with the summary of spending @@ -657,7 +655,7 @@ def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" # Act query = "Is Xi Li older than Namita? Just say the older persons full name" - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"].lower() # Assert @@ -681,7 +679,7 @@ def test_answer_using_file_filter(chat_client): query = ( 'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' ) - response = chat_client.post(f"/api/chat", json={"q": query}) + response = chat_client.post("/api/chat", json={"q": query}) response_message = response.json()["response"].lower() # Assert diff --git a/tests/test_org_to_entries.py b/tests/test_org_to_entries.py index 0196ef6c..61d33575 100644 --- a/tests/test_org_to_entries.py +++ b/tests/test_org_to_entries.py @@ -12,7 +12,7 @@ def test_configure_indexing_heading_only_entries(tmp_path): """Ensure entries with empty body are ignored, unless explicitly configured to index heading entries. Property drawers not considered Body. Ignore control characters for evaluating if Body empty.""" # Arrange - entry = f"""*** Heading + entry = """*** Heading :PROPERTIES: :ID: 42-42-42 :END: @@ -74,7 +74,7 @@ def test_entry_split_when_exceeds_max_tokens(): "Ensure entries with compiled words exceeding max_tokens are split." # Arrange tmp_path = "/tmp/test.org" - entry = f"""*** Heading + entry = """*** Heading \t\r Body Line """ @@ -99,7 +99,7 @@ def test_entry_split_when_exceeds_max_tokens(): def test_entry_split_drops_large_words(): "Ensure entries drops words larger than specified max word length from compiled version." # Arrange - entry_text = f"""First Line + entry_text = """First Line dog=1\n\r\t cat=10 car=4 @@ -124,7 +124,7 @@ book=2 def test_parse_org_file_into_single_entry_if_small(tmp_path): "Parse org file into single entry if it fits within the token limits." # Arrange - original_entry = f""" + original_entry = """ * Heading 1 body line 1 ** Subheading 1.1 @@ -133,7 +133,7 @@ body line 1.1 data = { f"{tmp_path}": original_entry, } - expected_entry = f""" + expected_entry = """ * Heading 1 body line 1 @@ -155,7 +155,7 @@ body line 1.1 def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path): "Parse org entry with child headings as single entry only if it fits within the tokens limits." # Arrange - entry = f""" + entry = """ * Heading 1 body line 1 ** Subheading 1.1 @@ -205,7 +205,7 @@ longer body line 2.1 def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path): "Parse org sibling entries as separate entries only if it fits within the tokens limits." # Arrange - entry = f""" + entry = """ * Heading 1 body line 1 ** Subheading 1.1 @@ -267,7 +267,7 @@ body line 3.1 def test_entry_with_body_to_entry(tmp_path): "Ensure entries with valid body text are loaded." # Arrange - entry = f"""*** Heading + entry = """*** Heading :PROPERTIES: :ID: 42-42-42 :END: @@ -290,7 +290,7 @@ def test_entry_with_body_to_entry(tmp_path): def test_file_with_entry_after_intro_text_to_entry(tmp_path): "Ensure intro text before any headings is indexed." # Arrange - entry = f""" + entry = """ Intro text * Entry Heading @@ -312,7 +312,7 @@ Intro text def test_file_with_no_headings_to_entry(tmp_path): "Ensure files with no heading, only body text are loaded." # Arrange - entry = f""" + entry = """ - Bullet point 1 - Bullet point 2 """ @@ -332,7 +332,7 @@ def test_file_with_no_headings_to_entry(tmp_path): def test_extract_entries_with_different_level_headings(tmp_path): "Extract org entries with different level headings." # Arrange - entry = f""" + entry = """ * Heading 1 ** Sub-Heading 1.1 * Heading 2 @@ -396,14 +396,16 @@ def test_line_number_tracking_in_recursive_split(): assert entry.uri is not None, f"Entry '{entry}' has a None URI." assert match is not None, f"URI format is incorrect: {entry.uri}" - assert ( - filepath_from_uri == org_file_path - ), f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'" + assert filepath_from_uri == org_file_path, ( + f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'" + ) # Ensure the first non-heading line in the compiled entry matches the line in the file assert ( cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip() - ), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" + ), ( + f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'" + ) # Helper Functions diff --git a/tests/test_orgnode.py b/tests/test_orgnode.py index 00c471b1..b3933ec2 100644 --- a/tests/test_orgnode.py +++ b/tests/test_orgnode.py @@ -8,7 +8,7 @@ from khoj.processor.content.org_mode import orgnode def test_parse_entry_with_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f"""Body Line 1""" + entry = """Body Line 1""" orgfile = create_file(tmp_path, entry) # Act @@ -30,7 +30,7 @@ def test_parse_entry_with_no_headings(tmp_path): def test_parse_minimal_entry(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f""" + entry = """ * Heading Body Line 1""" orgfile = create_file(tmp_path, entry) @@ -54,7 +54,7 @@ Body Line 1""" def test_parse_complete_entry(tmp_path): "Test parsing of entry with all important fields" # Arrange - entry = f""" + entry = """ *** DONE [#A] Heading :Tag1:TAG2:tag3: CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> :PROPERTIES: @@ -89,7 +89,7 @@ Body Line 2""" def test_render_entry_with_property_drawer_and_empty_body(tmp_path): "Render heading entry with property drawer" # Arrange - entry_to_render = f""" + entry_to_render = """ *** [#A] Heading1 :tag1: :PROPERTIES: :ID: 111-111-111-1111-1111 @@ -116,7 +116,7 @@ def test_render_entry_with_property_drawer_and_empty_body(tmp_path): def test_all_links_to_entry_rendered(tmp_path): "Ensure all links to entry rendered in property drawer from entry" # Arrange - entry = f""" + entry = """ *** [#A] Heading :tag1: :PROPERTIES: :ID: 123-456-789-4234-1231 @@ -133,7 +133,7 @@ Body Line 2 # Assert # SOURCE link rendered with Heading # ID link rendered with ID - assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}" + assert ":ID: id:123-456-789-4234-1231" in f"{entries[0]}" # LINE link rendered with line number assert f":LINE: file://{orgfile}#line=2" in f"{entries[0]}" # LINE link rendered with line number @@ -144,7 +144,7 @@ Body Line 2 def test_parse_multiple_entries(tmp_path): "Test parsing of multiple entries" # Arrange - content = f""" + content = """ *** FAILED [#A] Heading1 :tag1: CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun> :PROPERTIES: @@ -176,12 +176,12 @@ Body 2 # Assert assert len(entries) == 2 for index, entry in enumerate(entries): - assert entry.heading == f"Heading{index+1}" + assert entry.heading == f"Heading{index + 1}" assert entry.todo == "FAILED" if index == 0 else "CANCELLED" - assert entry.tags == [f"tag{index+1}"] - assert entry.body == f"- Clocked Log {index+1}\n\nBody {index+1}\n\n" + assert entry.tags == [f"tag{index + 1}"] + assert entry.body == f"- Clocked Log {index + 1}\n\nBody {index + 1}\n\n" assert entry.priority == "A" - assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}" + assert entry.Property("ID") == f"id:123-456-789-4234-000{index + 1}" assert entry.closed == datetime.date(1984, 4, index + 1) assert entry.scheduled == datetime.date(1984, 4, index + 1) assert entry.deadline == datetime.date(1984, 4, index + 1) @@ -194,7 +194,7 @@ Body 2 def test_parse_entry_with_empty_title(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f"""#+TITLE: + entry = """#+TITLE: Body Line 1""" orgfile = create_file(tmp_path, entry) @@ -217,7 +217,7 @@ Body Line 1""" def test_parse_entry_with_title_and_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f"""#+TITLE: test + entry = """#+TITLE: test Body Line 1""" orgfile = create_file(tmp_path, entry) @@ -241,7 +241,7 @@ Body Line 1""" def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path): "Test parsing of entry with minimal fields" # Arrange - entry = f"""#+TITLE: title1 + entry = """#+TITLE: title1 Body Line 1 #+TITLE: title2 """ orgfile = create_file(tmp_path, entry) @@ -266,7 +266,7 @@ Body Line 1 def test_parse_org_with_intro_text_before_heading(tmp_path): "Test parsing of org file with intro text before heading" # Arrange - body = f"""#+TITLE: Title + body = """#+TITLE: Title intro body * Entry Heading entry body @@ -290,7 +290,7 @@ entry body def test_parse_org_with_intro_text_multiple_titles_and_heading(tmp_path): "Test parsing of org file with intro text, multiple titles and heading entry" # Arrange - body = f"""#+TITLE: Title1 + body = """#+TITLE: Title1 intro body * Entry Heading entry body @@ -314,7 +314,7 @@ entry body def test_parse_org_with_single_ancestor_heading(tmp_path): "Parse org entries with parent headings context" # Arrange - body = f""" + body = """ * Heading 1 body 1 ** Sub Heading 1 @@ -336,7 +336,7 @@ body 1 def test_parse_org_with_multiple_ancestor_headings(tmp_path): "Parse org entries with parent headings context" # Arrange - body = f""" + body = """ * Heading 1 body 1 ** Sub Heading 1 @@ -362,7 +362,7 @@ sub sub body 1 def test_parse_org_with_multiple_ancestor_headings_of_siblings(tmp_path): "Parse org entries with parent headings context" # Arrange - body = f""" + body = """ * Heading 1 body 1 ** Sub Heading 1 diff --git a/tests/test_plaintext_to_entries.py b/tests/test_plaintext_to_entries.py index 558832d3..40cc3aa9 100644 --- a/tests/test_plaintext_to_entries.py +++ b/tests/test_plaintext_to_entries.py @@ -7,7 +7,7 @@ from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEnt def test_plaintext_file(): "Convert files with no heading to jsonl." # Arrange - raw_entry = f""" + raw_entry = """ Hi, I am a plaintext file and I have some plaintext words. """ plaintextfile = "test.txt" diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 9e532429..d462c34a 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -145,9 +145,9 @@ def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: Kho text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) # Assert - assert ( - "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message - ), "new entry not split by max tokens" + assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, ( + "new entry not split by max tokens" + ) # ---------------------------------------------------------------------------------------------------- @@ -198,9 +198,9 @@ conda activate khoj ) # Assert - assert ( - "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message - ), "new entry not split by max tokens" + assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, ( + "new entry not split by max tokens" + ) # ----------------------------------------------------------------------------------------------------