Format server code with ruff recommendations

This commit is contained in:
Debanjum
2025-08-01 00:10:34 -07:00
parent 4a3ed9e5a4
commit c8e07e86e4
65 changed files with 407 additions and 370 deletions

View File

@@ -14,6 +14,7 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
from django.urls import path

View File

@@ -1910,9 +1910,9 @@ class EntryAdapters:
owner_filter = Q()
if user != None:
if user is not None:
owner_filter = Q(user=user)
if agent != None:
if agent is not None:
owner_filter |= Q(agent=agent)
if owner_filter == Q():
@@ -1972,9 +1972,9 @@ class EntryAdapters:
):
owner_filter = Q()
if user != None:
if user is not None:
owner_filter = Q(user=user)
if agent != None:
if agent is not None:
owner_filter |= Q(agent=agent)
if owner_filter == Q():

View File

@@ -1,5 +1,4 @@
from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Exists, OuterRef
from khoj.database.models import Entry, FileObject

View File

@@ -41,7 +41,7 @@ def update_conversation_id_in_job_state(apps, schema_editor):
job.save()
except Conversation.DoesNotExist:
pass
except LookupError as e:
except LookupError:
pass

View File

@@ -1,6 +1,6 @@
# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
from django.db import migrations, models
from django.db import migrations
# This script was written alongside when Pydantic validation was added to the Conversation conversation_log field.

View File

@@ -551,12 +551,12 @@ class TextToImageModelConfig(DbBaseModel):
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.ai_model_api:
error[
"api_key"
] = "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them."
error[
"ai_model_api"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error["api_key"] = (
"Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them."
)
error["ai_model_api"] = (
"Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
)
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."

View File

@@ -1,3 +1 @@
from django.test import TestCase
# Create your tests here.

View File

@@ -1,5 +1,5 @@
""" Main module for Khoj
isort:skip_file
"""Main module for Khoj
isort:skip_file
"""
from contextlib import redirect_stdout
@@ -189,7 +189,7 @@ def run(should_start_server=True):
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
if not os.path.exists(static_dir):
os.mkdir(static_dir)
app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir)
app.mount("/static", StaticFiles(directory=static_dir), name=static_dir)
# Configure Middleware
configure_middleware(app, state.ssl_config)

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys

View File

@@ -51,7 +51,7 @@ class GithubToEntries(TextToEntries):
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
if is_none_or_empty(self.config.pat_token):
logger.warning(
f"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply."
"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply."
)
current_entries = []
for repo in self.config.repos:
@@ -137,7 +137,7 @@ class GithubToEntries(TextToEntries):
# Find all markdown files in the repository
if item["type"] == "blob" and item["path"].endswith(".md"):
# Create URL for each markdown file on Github
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
# Add markdown file contents and URL to list
markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
@@ -145,19 +145,19 @@ class GithubToEntries(TextToEntries):
# Find all org files in the repository
elif item["type"] == "blob" and item["path"].endswith(".org"):
# Create URL for each org file on Github
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
# Add org file contents and URL to list
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
# Find, index remaining non-binary files in the repository
elif item["type"] == "blob":
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
content_bytes = self.get_file_contents(item["url"], decode=False)
content_type, content_str = None, None
try:
content_type = magika.identify_bytes(content_bytes).output.group
except:
except Exception:
logger.error(f"Unable to identify content type of file at {url_path}. Skip indexing it")
continue
@@ -165,7 +165,7 @@ class GithubToEntries(TextToEntries):
if content_type in ["text", "code"]:
try:
content_str = content_bytes.decode("utf-8")
except:
except Exception:
logger.error(f"Unable to decode content of file at {url_path}. Skip indexing it")
continue
plaintext_files += [{"content": content_str, "path": url_path}]

View File

@@ -1,4 +1,3 @@
import base64
import logging
import os
from datetime import datetime, timezone

View File

@@ -1,6 +1,5 @@
import logging
import re
from pathlib import Path
from typing import Dict, List, Tuple
import urllib3.util
@@ -86,7 +85,7 @@ class MarkdownToEntries(TextToEntries):
# If content is small or content has no children headings, save it as a single entry
if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE
rf"^#{{{len(ancestry) + 1},}}\s", markdown_content, flags=re.MULTILINE
):
# Create entry with line number information
entry_with_line_info = (markdown_content_with_ancestry, markdown_file, start_line)
@@ -160,7 +159,7 @@ class MarkdownToEntries(TextToEntries):
calculated_line = start_line if start_line > 0 else 1
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
if type(raw_filename) == str and re.search(r"^https?://", raw_filename):
if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
# Escape the URL to avoid issues with special characters
entry_filename = urllib3.util.parse_url(raw_filename).url
uri = entry_filename

View File

@@ -91,7 +91,7 @@ class NotionToEntries(TextToEntries):
json=self.body_params,
).json()
responses.append(result)
if result.get("has_more", False) == False:
if not result.get("has_more", False):
break
else:
self.body_params.update({"start_cursor": result["next_cursor"]})
@@ -118,7 +118,7 @@ class NotionToEntries(TextToEntries):
page_id = page["id"]
title, content = self.get_page_content(page_id)
if title == None or content == None:
if title is None or content is None:
return []
current_entries = []
@@ -126,11 +126,11 @@ class NotionToEntries(TextToEntries):
for block in content.get("results", []):
block_type = block.get("type")
if block_type == None:
if block_type is None:
continue
block_data = block[block_type]
if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0:
if block_data.get("rich_text") is None or len(block_data["rich_text"]) == 0:
# There's no text to handle here.
continue
@@ -179,7 +179,7 @@ class NotionToEntries(TextToEntries):
results = children.get("results", [])
for child in results:
child_type = child.get("type")
if child_type == None:
if child_type is None:
continue
child_data = child[child_type]
if child_data.get("rich_text") and len(child_data["rich_text"]) > 0:

View File

@@ -8,7 +8,6 @@ from khoj.database.models import KhojUser
from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.org_mode.orgnode import Orgnode
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
@@ -103,7 +102,7 @@ class OrgToEntries(TextToEntries):
# If content is small or content has no children headings, save it as a single entry
# Note: This is the terminating condition for this recursive function
if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE
rf"^\*{{{len(ancestry) + 1},}}\s", org_content, re.MULTILINE
):
orgnode_content_with_ancestry = orgnode.makelist(
org_content_with_ancestry, org_file, start_line=start_line, ancestry_lines=len(ancestry)
@@ -195,7 +194,7 @@ class OrgToEntries(TextToEntries):
if not entry_heading and parsed_entry.level > 0:
base_level = parsed_entry.level
# Indent entry by 1 heading level as ancestry is prepended as top level heading
heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else ""
heading = f"{'*' * (parsed_entry.level - base_level + 2)} {todo_str}" if parsed_entry.level > 0 else ""
if parsed_entry.heading:
heading += f"{parsed_entry.heading}."
@@ -212,10 +211,10 @@ class OrgToEntries(TextToEntries):
compiled += f"\t {tags_str}."
if parsed_entry.closed:
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
compiled += f"\n Closed on {parsed_entry.closed.strftime('%Y-%m-%d')}."
if parsed_entry.scheduled:
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
compiled += f"\n Scheduled for {parsed_entry.scheduled.strftime('%Y-%m-%d')}."
if parsed_entry.hasBody:
compiled += f"\n {parsed_entry.body}"

View File

@@ -65,7 +65,7 @@ def makelist(file, filename, start_line: int = 1, ancestry_lines: int = 0) -> Li
"""
ctr = 0
if type(file) == str:
if isinstance(file, str):
f = file.splitlines()
else:
f = file
@@ -512,11 +512,11 @@ class Orgnode(object):
if self._closed or self._scheduled or self._deadline:
n = n + indent
if self._closed:
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
n = n + f"CLOSED: [{self._closed.strftime('%Y-%m-%d %a')}] "
if self._scheduled:
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
n = n + f"SCHEDULED: <{self._scheduled.strftime('%Y-%m-%d %a')}> "
if self._deadline:
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
n = n + f"DEADLINE: <{self._deadline.strftime('%Y-%m-%d %a')}> "
if self._closed or self._scheduled or self._deadline:
n = n + "\n"

View File

@@ -1,6 +1,5 @@
import logging
import re
from pathlib import Path
from typing import Dict, List, Tuple
import urllib3
@@ -97,7 +96,7 @@ class PlaintextToEntries(TextToEntries):
for parsed_entry in parsed_entries:
raw_filename = entry_to_file_map[parsed_entry]
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
if type(raw_filename) == str and re.search(r"^https?://", raw_filename):
if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
# Escape the URL to avoid issues with special characters
entry_filename = urllib3.util.parse_url(raw_filename).url
else:

View File

@@ -30,8 +30,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter()
@abstractmethod
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
...
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ...
@staticmethod
def hash_func(key: str) -> Callable:

View File

@@ -194,7 +194,7 @@ def gemini_completion_with_backoff(
or not response.candidates[0].content
or response.candidates[0].content.parts is None
):
raise ValueError(f"Failed to get response from model.")
raise ValueError("Failed to get response from model.")
raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
if response.function_calls:
function_calls = [
@@ -212,7 +212,7 @@ def gemini_completion_with_backoff(
response = None
# Handle 429 rate limit errors directly
if e.code == 429:
response_text = f"My brain is exhausted. Can you please try again in a bit?"
response_text = "My brain is exhausted. Can you please try again in a bit?"
# Log the full error details for debugging
logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
# Handle other errors
@@ -361,7 +361,7 @@ def handle_gemini_response(
# Ensure we have a proper list of candidates
if not isinstance(candidates, list):
message = f"\nUnexpected response format. Try again."
message = "\nUnexpected response format. Try again."
stopped = True
return message, stopped

View File

@@ -2,7 +2,6 @@ import json
import logging
import os
from copy import deepcopy
from functools import partial
from time import perf_counter
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
from urllib.parse import urlparse
@@ -284,9 +283,9 @@ async def chat_completion_with_backoff(
if len(system_messages) > 0:
first_system_message_index, first_system_message = system_messages[0]
first_system_message_content = first_system_message["content"]
formatted_messages[first_system_message_index][
"content"
] = f"{first_system_message_content}\nFormatting re-enabled"
formatted_messages[first_system_message_index]["content"] = (
f"{first_system_message_content}\nFormatting re-enabled"
)
elif is_twitter_reasoning_model(model_name, api_base_url):
reasoning_effort = "high" if deepthought else "low"
# Grok-4 models do not support reasoning_effort parameter

View File

@@ -1,7 +1,6 @@
import base64
import json
import logging
import math
import mimetypes
import os
import re
@@ -18,7 +17,7 @@ import requests
import tiktoken
import yaml
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
from pydantic import BaseModel, ConfigDict, ValidationError
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
@@ -47,7 +46,11 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
try:
from git import Repo
import importlib.util
git_spec = importlib.util.find_spec("git")
if git_spec is None:
raise ImportError
except ImportError:
if is_promptrace_enabled():
logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.")
@@ -294,7 +297,7 @@ def construct_chat_history_for_operator(conversation_history: List[ChatMessageMo
if chat.by == "you" and chat.message:
content = [{"type": "text", "text": chat.message}]
for file in chat.queryFiles or []:
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
content += [{"type": "text", "text": f"## File: {file['name']}\n\n{file['content']}"}]
user_message = AgentMessage(role="user", content=content)
elif chat.by == "khoj" and chat.message:
chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
@@ -311,7 +314,10 @@ def construct_tool_chat_history(
If no tool is provided inferred query for all tools used are added.
"""
chat_history: list = []
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
def base_extractor(iteration: ResearchIteration) -> List[str]:
return []
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
ConversationCommand.SemanticSearchFiles: (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
@@ -498,7 +504,7 @@ async def save_to_conversation_log(
logger.info(
f"""
Saved Conversation Turn ({db_conversation.id if db_conversation else 'N/A'}):
Saved Conversation Turn ({db_conversation.id if db_conversation else "N/A"}):
You ({user.username}): "{q}"
Khoj: "{chat_response}"
@@ -625,7 +631,7 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(chat.operatorContext):
operator_context = chat.operatorContext
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
operator_content = "\n\n".join([f"## Task: {oc['query']}\n{oc['response']}\n" for oc in operator_context])
message_context += [
{
"type": "text",
@@ -744,7 +750,7 @@ def get_encoder(
else:
# as tiktoken doesn't recognize o1 model series yet
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
except:
except Exception:
encoder = tiktoken.encoding_for_model(default_tokenizer)
if state.verbose > 2:
logger.debug(
@@ -846,9 +852,9 @@ def truncate_messages(
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
if total_tokens > max_prompt_size:
# At this point, a single message with a single content part of type dict should remain
assert (
len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
), "Expected a single message with a single content part remaining at this point in truncation"
assert len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict), (
"Expected a single message with a single content part remaining at this point in truncation"
)
# Collate message content into single string to ease truncation
part = messages[0].content[0]

View File

@@ -1,8 +1,6 @@
import logging
from typing import List
from urllib.parse import urlparse
import openai
import requests
import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer

View File

@@ -108,12 +108,12 @@ async def text_to_image(
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
message = "Image generation blocked by OpenAI due to policy violation" # type: ignore
yield image_url or image, status_code, message
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed using OpenAI" # type: ignore
message = "Image generation failed using OpenAI" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message
return
@@ -199,7 +199,7 @@ def generate_image_with_stability(
# Call Stability AI API to generate image
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={

View File

@@ -11,7 +11,7 @@ from khoj.processor.conversation.utils import (
OperatorRun,
construct_chat_history_for_operator,
)
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_actions import RequestUserAction
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
@@ -59,7 +59,7 @@ async def operate_environment(
if not reasoning_model or not reasoning_model.vision_enabled:
reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
if not reasoning_model:
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
raise ValueError("No vision enabled chat model found. Configure a vision chat model to operate environment.")
# Create conversation history from conversation log
chat_history = construct_chat_history_for_operator(conversation_log)

View File

@@ -1,14 +1,27 @@
import json
import logging
from textwrap import dedent
from typing import List, Optional
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import AgentActResult
from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeypressAction,
OperatorAction,
Point,
ScreenshotAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics

View File

@@ -18,7 +18,22 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion
from PIL import Image
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeyDownAction,
KeypressAction,
KeyUpAction,
MoveAction,
OperatorAction,
RequestUserAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics
@@ -122,11 +137,10 @@ class GroundingAgentUitars:
)
temperature = self.temperature
top_k = self.top_k
try_times = 3
while not parsed_responses:
if try_times <= 0:
logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
logger.warning("Reach max retry times to fetch response from client, as error flag.")
return "client error\nFAIL", []
try:
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
@@ -163,7 +177,6 @@ class GroundingAgentUitars:
prediction = None
try_times -= 1
temperature = 1
top_k = -1
if prediction is None:
return "client error\nFAIL", []
@@ -264,9 +277,9 @@ class GroundingAgentUitars:
raise ValueError(f"Unsupported environment type: {environment_type}")
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), (
"The number of observations and actions should be the same."
)
self.history_images.append(base64.b64decode(current_state.screenshot))
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
@@ -524,7 +537,7 @@ class GroundingAgentUitars:
parsed_actions = [self.parse_action_string(action.replace("\n", "\\n").lstrip()) for action in all_action]
actions: list[dict] = []
for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance == None:
if action_instance is None:
print(f"Action can't parse: {raw_str}")
raise ValueError(f"Action can't parse: {raw_str}")
action_type = action_instance["function"]
@@ -756,7 +769,7 @@ class GroundingAgentUitars:
The pyautogui code string
"""
pyautogui_code = f"import pyautogui\nimport time\n"
pyautogui_code = "import pyautogui\nimport time\n"
actions = []
if isinstance(responses, dict):
responses = [responses]
@@ -774,7 +787,7 @@ class GroundingAgentUitars:
if response_id == 0:
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
else:
pyautogui_code += f"\ntime.sleep(1)\n"
pyautogui_code += "\ntime.sleep(1)\n"
action_dict = response
action_type = action_dict.get("action_type")
@@ -846,17 +859,17 @@ class GroundingAgentUitars:
if content:
if input_swap:
actions += TypeAction()
pyautogui_code += f"\nimport pyperclip"
pyautogui_code += "\nimport pyperclip"
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += f"\ntime.sleep(0.5)\n"
pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
pyautogui_code += "\npyautogui.press('enter')"
else:
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
pyautogui_code += f"\ntime.sleep(0.5)\n"
pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
pyautogui_code += "\npyautogui.press('enter')"
elif action_type in ["drag", "select"]:
# Parsing drag or select action based on start and end_boxes
@@ -869,9 +882,7 @@ class GroundingAgentUitars:
x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2]
ex = round(float((x1 + x2) / 2) * image_width, 3)
ey = round(float((y1 + y2) / 2) * image_height, 3)
pyautogui_code += (
f"\npyautogui.moveTo({sx}, {sy})\n" f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
)
pyautogui_code += f"\npyautogui.moveTo({sx}, {sy})\n\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
elif action_type == "scroll":
# Parsing scroll action
@@ -888,11 +899,11 @@ class GroundingAgentUitars:
y = None
direction = action_inputs.get("direction", "")
if x == None:
if x is None:
if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5)"
pyautogui_code += "\npyautogui.scroll(5)"
elif "down" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(-5)"
pyautogui_code += "\npyautogui.scroll(-5)"
else:
if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
@@ -923,7 +934,7 @@ class GroundingAgentUitars:
pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
elif action_type in ["finished"]:
pyautogui_code = f"DONE"
pyautogui_code = "DONE"
else:
pyautogui_code += f"\n# Unrecognized action type: {action_type}"

View File

@@ -11,7 +11,32 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo
from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
CursorPositionAction,
DoubleClickAction,
DragAction,
GotoAction,
HoldKeyAction,
KeypressAction,
MouseDownAction,
MouseUpAction,
MoveAction,
NoopAction,
OperatorAction,
Point,
ScreenshotAction,
ScrollAction,
TerminalAction,
TextEditorCreateAction,
TextEditorInsertAction,
TextEditorStrReplaceAction,
TextEditorViewAction,
TripleClickAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
@@ -518,7 +543,7 @@ class AnthropicOperatorAgent(OperatorAgent):
def model_default_headers(self) -> list[str]:
"""Get the default computer use headers for the given model."""
if self.vision_model.name.startswith("claude-3-7-sonnet"):
return [f"computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
return ["computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
return ["computer-use-2025-01-24"]
else:
@@ -538,7 +563,7 @@ class AnthropicOperatorAgent(OperatorAgent):
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
@@ -563,7 +588,7 @@ class AnthropicOperatorAgent(OperatorAgent):
</SYSTEM_CAPABILITY>
<CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
</CONTEXT>
"""
).lstrip()

View File

@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union
from typing import List, Optional
from pydantic import BaseModel

View File

@@ -12,7 +12,7 @@ from khoj.processor.conversation.utils import (
)
from khoj.processor.operator.grounding_agent import GroundingAgent
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_actions import OperatorAction, WaitAction
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
@@ -181,7 +181,7 @@ class BinaryOperatorAgent(OperatorAgent):
elif action.type == "key_down":
rendered_parts += [f'**Action**: Press Key "{action.key}"']
elif action.type == "screenshot" and not current_state.screenshot:
rendered_parts += [f"**Error**: Failed to take screenshot"]
rendered_parts += ["**Error**: Failed to take screenshot"]
elif action.type == "goto":
rendered_parts += [f"**Action**: Open URL {action.url}"]
else:
@@ -317,7 +317,7 @@ class BinaryOperatorAgent(OperatorAgent):
# Introduction
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
* You are given the user's query and screenshots of the browser's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {env_state.url}.
# Your Task
@@ -362,7 +362,7 @@ class BinaryOperatorAgent(OperatorAgent):
# Introduction
* You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
* You are given the user's query and screenshots of the computer's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
# Your Task
* First look at the screenshots carefully to notice all pertinent information.

View File

@@ -1,6 +1,5 @@
import json
import logging
import platform
from copy import deepcopy
from datetime import datetime
from textwrap import dedent
@@ -10,7 +9,23 @@ from openai.types.responses import Response, ResponseOutputItem
from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import AgentMessage
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_actions import (
BackAction,
ClickAction,
DoubleClickAction,
DragAction,
GotoAction,
KeypressAction,
MoveAction,
NoopAction,
OperatorAction,
Point,
RequestUserAction,
ScreenshotAction,
ScrollAction,
TypeAction,
WaitAction,
)
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
@@ -152,7 +167,7 @@ class OpenAIOperatorAgent(OperatorAgent):
# Add screenshot data in openai message format
action_result["output"] = {
"type": "input_image",
"image_url": f'data:image/webp;base64,{result_content["image"]}',
"image_url": f"data:image/webp;base64,{result_content['image']}",
"current_url": result_content["url"],
}
elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1:
@@ -311,7 +326,7 @@ class OpenAIOperatorAgent(OperatorAgent):
elif block.type == "function_call":
if block.name == "goto":
args = json.loads(block.arguments)
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']
render_texts = [f"Open URL: {args.get('url', '[Missing URL]')}"]
else:
render_texts += [block.name]
elif block.type == "computer_call":
@@ -351,7 +366,7 @@ class OpenAIOperatorAgent(OperatorAgent):
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
@@ -374,7 +389,7 @@ class OpenAIOperatorAgent(OperatorAgent):
</SYSTEM_CAPABILITY>
<CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
</CONTEXT>
"""
).lstrip()

View File

@@ -247,7 +247,7 @@ class BrowserEnvironment(Environment):
case "drag":
if not isinstance(action, DragAction):
raise TypeError(f"Invalid action type for drag")
raise TypeError("Invalid action type for drag")
path = action.path
if not path:
error = "Missing path for drag action"

View File

@@ -532,7 +532,7 @@ class ComputerEnvironment(Environment):
else:
return {"success": False, "output": process.stdout, "error": process.stderr}
except asyncio.TimeoutError:
return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."}
return {"success": False, "output": "", "error": "Command timed out after 120 seconds."}
except Exception as e:
return {"success": False, "output": "", "error": str(e)}

View File

@@ -1,4 +1,3 @@
import json # Used for working with JSON data
import os
import requests # Used for making HTTP requests

View File

@@ -385,7 +385,7 @@ async def read_webpages(
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
logger.info("Inferring web pages to read")
urls = await infer_webpage_urls(
query,
max_webpages_to_read,

View File

@@ -93,7 +93,7 @@ async def run_code(
# Run Code
if send_status_func:
async for event in send_status_func(f"**Running code snippet**"):
async for event in send_status_func("**Running code snippet**"):
yield {ChatEvent.STATUS: event}
try:
with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):

View File

@@ -7,7 +7,6 @@ from typing import List, Optional, Union
import openai
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires
@@ -94,7 +93,7 @@ def update(
logger.error(error_msg, exc_info=True)
raise HTTPException(status_code=500, detail=error_msg)
else:
logger.info(f"📪 Server indexed content updated via API")
logger.info("📪 Server indexed content updated via API")
update_telemetry_state(
request=request,

View File

@@ -6,12 +6,11 @@ from typing import Dict, List, Optional
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import has_required_scope, requires
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (

View File

@@ -109,7 +109,7 @@ def post_automation(
except Exception as e:
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
return Response(
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
content="Unable to create automation. Ensure the automation doesn't already exist.",
media_type="text/plain",
status_code=500,
)

View File

@@ -10,7 +10,6 @@ from functools import partial
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import (
APIRouter,
Depends,
@@ -32,10 +31,10 @@ from khoj.database.adapters import (
PublicConversationAdapters,
aget_user_name,
)
from khoj.database.models import Agent, ChatMessageModel, KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import is_local_api
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.prompts import no_entries_found
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
@@ -65,11 +64,8 @@ from khoj.routers.helpers import (
acreate_title_from_history,
agenerate_chat_response,
aget_data_sources_and_output_format,
construct_automation_created_message,
create_automation,
gather_raw_query_files,
generate_mermaidjs_diagram,
generate_summary_from_files,
get_conversation_command,
get_message_from_queue,
is_query_empty,
@@ -89,13 +85,11 @@ from khoj.utils.helpers import (
convert_image_to_webp,
get_country_code_from_timezone,
get_country_name_from_timezone,
get_device,
is_env_var_true,
is_none_or_empty,
is_operator_enabled,
)
from khoj.utils.rawconfig import (
ChatRequestBody,
FileAttachment,
FileFilterRequest,
FilesFilterRequest,
@@ -689,7 +683,6 @@ async def event_generator(
region = body.region
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
raw_images = body.images
raw_query_files = body.files
@@ -853,7 +846,8 @@ async def event_generator(
if (
len(train_of_thought) > 0
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
and type(train_of_thought[-1]["data"]) == type(data) == str
and isinstance(train_of_thought[-1]["data"], str)
and isinstance(data, str)
):
train_of_thought[-1]["data"] += data
else:
@@ -1075,11 +1069,11 @@ async def event_generator(
# researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1:
logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}")
# Gather Context
## Extract Document References
if not ConversationCommand.Research in conversation_commands:
if ConversationCommand.Research not in conversation_commands:
try:
async for result in search_documents(
q,
@@ -1218,7 +1212,7 @@ async def event_generator(
else:
code_results = result
except ValueError as e:
program_execution_context.append(f"Failed to run code")
program_execution_context.append("Failed to run code")
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
@@ -1297,7 +1291,7 @@ async def event_generator(
inferred_queries.append(improved_image_prompt)
if generated_image is None or status_code != 200:
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
async for result in send_event(ChatEvent.STATUS, "Failed to generate image"):
yield result
else:
generated_images.append(generated_image)
@@ -1315,7 +1309,7 @@ async def event_generator(
yield result
if ConversationCommand.Diagram in conversation_commands:
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
async for result in send_event(ChatEvent.STATUS, "Creating diagram"):
yield result
inferred_queries = []
@@ -1372,7 +1366,7 @@ async def event_generator(
return
## Generate Text Output
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
async for result in send_event(ChatEvent.STATUS, "**Generating a well-informed response**"):
yield result
llm_response, chat_metadata = await agenerate_chat_response(

View File

@@ -3,7 +3,6 @@ import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import has_required_scope, requires

View File

@@ -117,7 +117,7 @@ async def subscribe(request: Request):
)
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
logger.info(f"Stripe subscription {event['type']} for {customer_email}")
return {"success": success}

View File

@@ -44,7 +44,7 @@ async def send_magic_link_email(email, unique_id, host):
{
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
"to": email,
"subject": f"Your login code to Khoj",
"subject": "Your login code to Khoj",
"html": html_content,
}
)
@@ -98,11 +98,11 @@ async def send_query_feedback(uquery, kquery, sentiment, user_email):
user_email=user_email if not is_none_or_empty(user_email) else "N/A",
)
# send feedback to fixed account
r = resend.Emails.send(
resend.Emails.send(
{
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
"to": "team@khoj.dev",
"subject": f"User Feedback",
"subject": "User Feedback",
"html": html_content,
}
)
@@ -127,7 +127,7 @@ def send_task_email(name, email, query, result, subject, is_image=False):
r = resend.Emails.send(
{
"sender": f'Khoj <{os.environ.get("RESEND_EMAIL", "khoj@khoj.dev")}>',
"sender": f"Khoj <{os.environ.get('RESEND_EMAIL', 'khoj@khoj.dev')}>",
"to": email,
"subject": f"{subject}",
"html": html_content,

View File

@@ -1,6 +1,5 @@
import asyncio
import base64
import concurrent.futures
import fnmatch
import hashlib
import json
@@ -47,14 +46,12 @@ from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
aget_user_by_email,
ais_user_subscribed,
create_khoj_token,
get_default_search_model,
get_khoj_tokens,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
is_user_subscribed,
run_with_process_lock,
)
from khoj.database.models import (
@@ -160,7 +157,7 @@ def validate_chat_model(user: KhojUser):
async def is_ready_to_chat(user: KhojUser):
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
if user_chat_model == None:
if user_chat_model is None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
if (
@@ -581,7 +578,7 @@ async def generate_online_subqueries(
)
return {q}
return response
except Exception as e:
except Exception:
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
return {q}
@@ -1172,8 +1169,8 @@ async def search_documents(
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
if (
not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands
ConversationCommand.Notes not in conversation_commands
and ConversationCommand.Default not in conversation_commands
and not agent_has_entries
):
yield compiled_references, inferred_queries, q
@@ -1325,8 +1322,8 @@ async def extract_questions(
logger.error(f"Invalid response for constructing subqueries: {response}")
return [query]
return queries
except:
logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.")
except Exception:
logger.warning("LLM returned invalid JSON. Falling back to using user message as search query.")
return [query]
@@ -1351,7 +1348,7 @@ async def execute_search(
return results
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
logger.warning("No query param (q) passed in API call to initiate search")
return results
# initialize variables
@@ -1364,7 +1361,7 @@ async def execute_search(
if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
logger.debug("Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
# Encode query with filter terms removed
@@ -1875,8 +1872,8 @@ class ApiUserRateLimiter:
user: KhojUser = websocket.scope["user"].object
subscribed = has_required_scope(websocket, ["premium"])
current_window = "today" if self.window == 60 * 60 * 24 else f"now"
next_window = "tomorrow" if self.window == 60 * 60 * 24 else f"in a bit"
current_window = "today" if self.window == 60 * 60 * 24 else "now"
next_window = "tomorrow" if self.window == 60 * 60 * 24 else "in a bit"
common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}."
# Remove requests outside of the time window
@@ -2219,7 +2216,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str, us
should_notify_result = response["decision"] == "Yes"
reason = response.get("reason", "unknown")
logger.info(
f'Decided to {"not " if not should_notify_result else ""}notify user of automation response because of reason: {reason}.'
f"Decided to {'not ' if not should_notify_result else ''}notify user of automation response because of reason: {reason}."
)
return should_notify_result
except Exception as e:
@@ -2313,7 +2310,7 @@ def scheduled_chat(
response_map = raw_response.json()
ai_response = response_map.get("response") or response_map.get("image")
is_image = False
if type(ai_response) == dict:
if isinstance(ai_response, dict):
is_image = ai_response.get("image") is not None
else:
ai_response = raw_response.text
@@ -2460,12 +2457,12 @@ async def aschedule_automation(
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
# Display next run time in user timezone instead of UTC
schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}'
schedule = f"{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime('%Z')}"
next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
# Remove /automated_task prefix from inferred_query
unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)
# Create the automation response
automation_icon_url = f"/static/assets/icons/automation.svg"
automation_icon_url = "/static/assets/icons/automation.svg"
return f"""
### ![]({automation_icon_url}) Created Automation
- Subject: **{subject}**
@@ -2713,13 +2710,13 @@ def configure_content(
t: Optional[state.SearchType] = state.SearchType.All,
) -> bool:
success = True
if t == None:
if t is None:
t = state.SearchType.All
if t is not None and t in [type.value for type in state.SearchType]:
t = state.SearchType(t)
if t is not None and not t.value in [type.value for type in state.SearchType]:
if t is not None and t.value not in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return False
@@ -2988,7 +2985,7 @@ async def grep_files(
query += f" {' and '.join(context_info)}"
if line_count > max_results:
if lines_before or lines_after:
query += f" for"
query += " for"
query += f" first {max_results} results"
return query

View File

@@ -15,7 +15,6 @@ from khoj.processor.conversation.utils import (
ResearchIteration,
ToolCall,
construct_iteration_history,
construct_structured_message,
construct_tool_chat_history,
load_complex_json,
)
@@ -24,7 +23,6 @@ from khoj.processor.tools.online_search import read_webpages_content, search_onl
from khoj.processor.tools.run_code import run_code
from khoj.routers.helpers import (
ChatEvent,
generate_summary_from_files,
get_message_from_queue,
grep_files,
list_files,
@@ -184,7 +182,7 @@ async def apick_next_tool(
# TODO: Handle multiple tool calls.
response_text = response.text
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
except Exception as e:
except Exception:
# Otherwise assume the model has decided to end the research run and respond to the user.
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
@@ -199,7 +197,7 @@ async def apick_next_tool(
if i.warning is None and isinstance(i.query, ToolCall)
}
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
warning = "Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
elif send_status_func and not is_none_or_empty(response.thought):
async for event in send_status_func(response.thought):

View File

@@ -4,12 +4,10 @@ from typing import List
class BaseFilter(ABC):
@abstractmethod
def get_filter_terms(self, query: str) -> List[str]:
...
def get_filter_terms(self, query: str) -> List[str]: ...
def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0
@abstractmethod
def defilter(self, query: str) -> str:
...
def defilter(self, query: str) -> str: ...

View File

@@ -9,9 +9,8 @@ from asgiref.sync import sync_to_async
from sentence_transformers import util
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.models import Agent
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, timer

View File

@@ -77,7 +77,7 @@ class AsyncIteratorWrapper:
def is_none_or_empty(item):
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
return item is None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
def to_snake_case_from_dash(item: str):
@@ -97,7 +97,7 @@ def get_from_dict(dictionary, *args):
Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
current = dictionary
for arg in args:
if not hasattr(current, "__iter__") or not arg in current:
if not hasattr(current, "__iter__") or arg not in current:
return None
current = current[arg]
return current
@@ -751,7 +751,7 @@ def is_valid_url(url: str) -> bool:
try:
result = urlparse(url.strip())
return all([result.scheme, result.netloc])
except:
except Exception:
return False
@@ -759,7 +759,7 @@ def is_internet_connected():
try:
response = requests.head("https://www.google.com")
return response.status_code == 200
except:
except Exception:
return False

View File

@@ -60,9 +60,7 @@ def initialization(interactive: bool = True):
]
default_chat_models = known_available_models + other_available_models
except Exception as e:
logger.warning(
f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}"
)
logger.warning(f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}")
# Set up OpenAI's online chat models
openai_configured, openai_provider = _setup_chat_model_provider(

View File

@@ -8,12 +8,10 @@ from tqdm import trange
class BaseEncoder(ABC):
@abstractmethod
def __init__(self, model_name: str, device: torch.device = None, **kwargs):
...
def __init__(self, model_name: str, device: torch.device = None, **kwargs): ...
@abstractmethod
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
...
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: ...
class OpenAI(BaseEncoder):

View File

@@ -1,8 +1,7 @@
# System Packages
import json
import uuid
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional
from pydantic import BaseModel

View File

@@ -2,7 +2,7 @@ import os
import threading
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List
from typing import Dict, List
from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI

View File

@@ -30,7 +30,7 @@ def v1_telemetry(telemetry_data: List[Dict[str, str]]):
try:
for row in telemetry_data:
posthog.capture(row["server_id"], "api_request", row)
except Exception as e:
except Exception:
raise HTTPException(
status_code=500,
            detail="Could not POST request to new khoj telemetry server. Contact developer to get this fixed.",