mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Format server code with ruff recommendations
This commit is contained in:
@@ -14,6 +14,7 @@ Including another URLconf
|
||||
1. Import the include() function: from django.urls import include, path
|
||||
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||
"""
|
||||
|
||||
from django.contrib import admin
|
||||
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
|
||||
from django.urls import path
|
||||
|
||||
@@ -1910,9 +1910,9 @@ class EntryAdapters:
|
||||
|
||||
owner_filter = Q()
|
||||
|
||||
if user != None:
|
||||
if user is not None:
|
||||
owner_filter = Q(user=user)
|
||||
if agent != None:
|
||||
if agent is not None:
|
||||
owner_filter |= Q(agent=agent)
|
||||
|
||||
if owner_filter == Q():
|
||||
@@ -1972,9 +1972,9 @@ class EntryAdapters:
|
||||
):
|
||||
owner_filter = Q()
|
||||
|
||||
if user != None:
|
||||
if user is not None:
|
||||
owner_filter = Q(user=user)
|
||||
if agent != None:
|
||||
if agent is not None:
|
||||
owner_filter |= Q(agent=agent)
|
||||
|
||||
if owner_filter == Q():
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.db.models import Exists, OuterRef
|
||||
|
||||
from khoj.database.models import Entry, FileObject
|
||||
|
||||
@@ -41,7 +41,7 @@ def update_conversation_id_in_job_state(apps, schema_editor):
|
||||
job.save()
|
||||
except Conversation.DoesNotExist:
|
||||
pass
|
||||
except LookupError as e:
|
||||
except LookupError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Made manually by sabaimran for use by Django 5.0.9 on 2024-12-01 16:59
|
||||
|
||||
from django.db import migrations, models
|
||||
from django.db import migrations
|
||||
|
||||
# This script was written alongside when Pydantic validation was added to the Conversation conversation_log field.
|
||||
|
||||
|
||||
@@ -551,12 +551,12 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
error = {}
|
||||
if self.model_type == self.ModelType.OPENAI:
|
||||
if self.api_key and self.ai_model_api:
|
||||
error[
|
||||
"api_key"
|
||||
] = "Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them."
|
||||
error[
|
||||
"ai_model_api"
|
||||
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||
error["api_key"] = (
|
||||
"Both API key and AI Model API cannot be set for OpenAI models. Please set only one of them."
|
||||
)
|
||||
error["ai_model_api"] = (
|
||||
"Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
|
||||
)
|
||||
if self.model_type != self.ModelType.OPENAI and self.model_type != self.ModelType.GOOGLE:
|
||||
if not self.api_key:
|
||||
error["api_key"] = "The API key field must be set for non OpenAI, non Google models."
|
||||
|
||||
@@ -1,3 +1 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
|
||||
@@ -189,7 +189,7 @@ def run(should_start_server=True):
|
||||
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
|
||||
if not os.path.exists(static_dir):
|
||||
os.mkdir(static_dir)
|
||||
app.mount(f"/static", StaticFiles(directory=static_dir), name=static_dir)
|
||||
app.mount("/static", StaticFiles(directory=static_dir), name=static_dir)
|
||||
|
||||
# Configure Middleware
|
||||
configure_middleware(app, state.ssl_config)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
"""Django's command-line utility for administrative tasks."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class GithubToEntries(TextToEntries):
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
if is_none_or_empty(self.config.pat_token):
|
||||
logger.warning(
|
||||
f"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply."
|
||||
"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply."
|
||||
)
|
||||
current_entries = []
|
||||
for repo in self.config.repos:
|
||||
@@ -137,7 +137,7 @@ class GithubToEntries(TextToEntries):
|
||||
# Find all markdown files in the repository
|
||||
if item["type"] == "blob" and item["path"].endswith(".md"):
|
||||
# Create URL for each markdown file on Github
|
||||
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
|
||||
|
||||
# Add markdown file contents and URL to list
|
||||
markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
||||
@@ -145,19 +145,19 @@ class GithubToEntries(TextToEntries):
|
||||
# Find all org files in the repository
|
||||
elif item["type"] == "blob" and item["path"].endswith(".org"):
|
||||
# Create URL for each org file on Github
|
||||
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
|
||||
|
||||
# Add org file contents and URL to list
|
||||
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
||||
|
||||
# Find, index remaining non-binary files in the repository
|
||||
elif item["type"] == "blob":
|
||||
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||
url_path = f"https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item['path']}"
|
||||
content_bytes = self.get_file_contents(item["url"], decode=False)
|
||||
content_type, content_str = None, None
|
||||
try:
|
||||
content_type = magika.identify_bytes(content_bytes).output.group
|
||||
except:
|
||||
except Exception:
|
||||
logger.error(f"Unable to identify content type of file at {url_path}. Skip indexing it")
|
||||
continue
|
||||
|
||||
@@ -165,7 +165,7 @@ class GithubToEntries(TextToEntries):
|
||||
if content_type in ["text", "code"]:
|
||||
try:
|
||||
content_str = content_bytes.decode("utf-8")
|
||||
except:
|
||||
except Exception:
|
||||
logger.error(f"Unable to decode content of file at {url_path}. Skip indexing it")
|
||||
continue
|
||||
plaintext_files += [{"content": content_str, "path": url_path}]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import urllib3.util
|
||||
@@ -160,7 +159,7 @@ class MarkdownToEntries(TextToEntries):
|
||||
calculated_line = start_line if start_line > 0 else 1
|
||||
|
||||
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
|
||||
if type(raw_filename) == str and re.search(r"^https?://", raw_filename):
|
||||
if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
|
||||
# Escape the URL to avoid issues with special characters
|
||||
entry_filename = urllib3.util.parse_url(raw_filename).url
|
||||
uri = entry_filename
|
||||
|
||||
@@ -91,7 +91,7 @@ class NotionToEntries(TextToEntries):
|
||||
json=self.body_params,
|
||||
).json()
|
||||
responses.append(result)
|
||||
if result.get("has_more", False) == False:
|
||||
if not result.get("has_more", False):
|
||||
break
|
||||
else:
|
||||
self.body_params.update({"start_cursor": result["next_cursor"]})
|
||||
@@ -118,7 +118,7 @@ class NotionToEntries(TextToEntries):
|
||||
page_id = page["id"]
|
||||
title, content = self.get_page_content(page_id)
|
||||
|
||||
if title == None or content == None:
|
||||
if title is None or content is None:
|
||||
return []
|
||||
|
||||
current_entries = []
|
||||
@@ -126,11 +126,11 @@ class NotionToEntries(TextToEntries):
|
||||
for block in content.get("results", []):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == None:
|
||||
if block_type is None:
|
||||
continue
|
||||
block_data = block[block_type]
|
||||
|
||||
if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0:
|
||||
if block_data.get("rich_text") is None or len(block_data["rich_text"]) == 0:
|
||||
# There's no text to handle here.
|
||||
continue
|
||||
|
||||
@@ -179,7 +179,7 @@ class NotionToEntries(TextToEntries):
|
||||
results = children.get("results", [])
|
||||
for child in results:
|
||||
child_type = child.get("type")
|
||||
if child_type == None:
|
||||
if child_type is None:
|
||||
continue
|
||||
child_data = child[child_type]
|
||||
if child_data.get("rich_text") and len(child_data["rich_text"]) > 0:
|
||||
|
||||
@@ -8,7 +8,6 @@ from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.org_mode import orgnode
|
||||
from khoj.processor.content.org_mode.orgnode import Orgnode
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
@@ -212,10 +211,10 @@ class OrgToEntries(TextToEntries):
|
||||
compiled += f"\t {tags_str}."
|
||||
|
||||
if parsed_entry.closed:
|
||||
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
|
||||
compiled += f"\n Closed on {parsed_entry.closed.strftime('%Y-%m-%d')}."
|
||||
|
||||
if parsed_entry.scheduled:
|
||||
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
compiled += f"\n Scheduled for {parsed_entry.scheduled.strftime('%Y-%m-%d')}."
|
||||
|
||||
if parsed_entry.hasBody:
|
||||
compiled += f"\n {parsed_entry.body}"
|
||||
|
||||
@@ -65,7 +65,7 @@ def makelist(file, filename, start_line: int = 1, ancestry_lines: int = 0) -> Li
|
||||
"""
|
||||
ctr = 0
|
||||
|
||||
if type(file) == str:
|
||||
if isinstance(file, str):
|
||||
f = file.splitlines()
|
||||
else:
|
||||
f = file
|
||||
@@ -512,11 +512,11 @@ class Orgnode(object):
|
||||
if self._closed or self._scheduled or self._deadline:
|
||||
n = n + indent
|
||||
if self._closed:
|
||||
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
|
||||
n = n + f"CLOSED: [{self._closed.strftime('%Y-%m-%d %a')}] "
|
||||
if self._scheduled:
|
||||
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
|
||||
n = n + f"SCHEDULED: <{self._scheduled.strftime('%Y-%m-%d %a')}> "
|
||||
if self._deadline:
|
||||
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
|
||||
n = n + f"DEADLINE: <{self._deadline.strftime('%Y-%m-%d %a')}> "
|
||||
if self._closed or self._scheduled or self._deadline:
|
||||
n = n + "\n"
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import urllib3
|
||||
@@ -97,7 +96,7 @@ class PlaintextToEntries(TextToEntries):
|
||||
for parsed_entry in parsed_entries:
|
||||
raw_filename = entry_to_file_map[parsed_entry]
|
||||
# Check if raw_filename is a URL. If so, save it as is. If not, convert it to a Path.
|
||||
if type(raw_filename) == str and re.search(r"^https?://", raw_filename):
|
||||
if isinstance(raw_filename, str) and re.search(r"^https?://", raw_filename):
|
||||
# Escape the URL to avoid issues with special characters
|
||||
entry_filename = urllib3.util.parse_url(raw_filename).url
|
||||
else:
|
||||
|
||||
@@ -30,8 +30,7 @@ class TextToEntries(ABC):
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
@abstractmethod
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]:
|
||||
...
|
||||
def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = False) -> Tuple[int, int]: ...
|
||||
|
||||
@staticmethod
|
||||
def hash_func(key: str) -> Callable:
|
||||
|
||||
@@ -194,7 +194,7 @@ def gemini_completion_with_backoff(
|
||||
or not response.candidates[0].content
|
||||
or response.candidates[0].content.parts is None
|
||||
):
|
||||
raise ValueError(f"Failed to get response from model.")
|
||||
raise ValueError("Failed to get response from model.")
|
||||
raw_content = [part.model_dump() for part in response.candidates[0].content.parts]
|
||||
if response.function_calls:
|
||||
function_calls = [
|
||||
@@ -212,7 +212,7 @@ def gemini_completion_with_backoff(
|
||||
response = None
|
||||
# Handle 429 rate limit errors directly
|
||||
if e.code == 429:
|
||||
response_text = f"My brain is exhausted. Can you please try again in a bit?"
|
||||
response_text = "My brain is exhausted. Can you please try again in a bit?"
|
||||
# Log the full error details for debugging
|
||||
logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
|
||||
# Handle other errors
|
||||
@@ -361,7 +361,7 @@ def handle_gemini_response(
|
||||
|
||||
# Ensure we have a proper list of candidates
|
||||
if not isinstance(candidates, list):
|
||||
message = f"\nUnexpected response format. Try again."
|
||||
message = "\nUnexpected response format. Try again."
|
||||
stopped = True
|
||||
return message, stopped
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
@@ -284,9 +283,9 @@ async def chat_completion_with_backoff(
|
||||
if len(system_messages) > 0:
|
||||
first_system_message_index, first_system_message = system_messages[0]
|
||||
first_system_message_content = first_system_message["content"]
|
||||
formatted_messages[first_system_message_index][
|
||||
"content"
|
||||
] = f"{first_system_message_content}\nFormatting re-enabled"
|
||||
formatted_messages[first_system_message_index]["content"] = (
|
||||
f"{first_system_message_content}\nFormatting re-enabled"
|
||||
)
|
||||
elif is_twitter_reasoning_model(model_name, api_base_url):
|
||||
reasoning_effort = "high" if deepthought else "low"
|
||||
# Grok-4 models do not support reasoning_effort parameter
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
@@ -18,7 +17,7 @@ import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
@@ -47,7 +46,11 @@ from khoj.utils.yaml import yaml_dump
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from git import Repo
|
||||
import importlib.util
|
||||
|
||||
git_spec = importlib.util.find_spec("git")
|
||||
if git_spec is None:
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
if is_promptrace_enabled():
|
||||
logger.warning("GitPython not installed. `pip install gitpython` to use prompt tracer.")
|
||||
@@ -294,7 +297,7 @@ def construct_chat_history_for_operator(conversation_history: List[ChatMessageMo
|
||||
if chat.by == "you" and chat.message:
|
||||
content = [{"type": "text", "text": chat.message}]
|
||||
for file in chat.queryFiles or []:
|
||||
content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
|
||||
content += [{"type": "text", "text": f"## File: {file['name']}\n\n{file['content']}"}]
|
||||
user_message = AgentMessage(role="user", content=content)
|
||||
elif chat.by == "khoj" and chat.message:
|
||||
chat_history += [user_message, AgentMessage(role="assistant", content=chat.message)]
|
||||
@@ -311,7 +314,10 @@ def construct_tool_chat_history(
|
||||
If no tool is provided inferred query for all tools used are added.
|
||||
"""
|
||||
chat_history: list = []
|
||||
base_extractor: Callable[[ResearchIteration], List[str]] = lambda iteration: []
|
||||
|
||||
def base_extractor(iteration: ResearchIteration) -> List[str]:
|
||||
return []
|
||||
|
||||
extract_inferred_query_map: Dict[ConversationCommand, Callable[[ResearchIteration], List[str]]] = {
|
||||
ConversationCommand.SemanticSearchFiles: (
|
||||
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
|
||||
@@ -498,7 +504,7 @@ async def save_to_conversation_log(
|
||||
|
||||
logger.info(
|
||||
f"""
|
||||
Saved Conversation Turn ({db_conversation.id if db_conversation else 'N/A'}):
|
||||
Saved Conversation Turn ({db_conversation.id if db_conversation else "N/A"}):
|
||||
You ({user.username}): "{q}"
|
||||
|
||||
Khoj: "{chat_response}"
|
||||
@@ -625,7 +631,7 @@ def generate_chatml_messages_with_context(
|
||||
|
||||
if not is_none_or_empty(chat.operatorContext):
|
||||
operator_context = chat.operatorContext
|
||||
operator_content = "\n\n".join([f'## Task: {oc["query"]}\n{oc["response"]}\n' for oc in operator_context])
|
||||
operator_content = "\n\n".join([f"## Task: {oc['query']}\n{oc['response']}\n" for oc in operator_context])
|
||||
message_context += [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -744,7 +750,7 @@ def get_encoder(
|
||||
else:
|
||||
# as tiktoken doesn't recognize o1 model series yet
|
||||
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
||||
except:
|
||||
except Exception:
|
||||
encoder = tiktoken.encoding_for_model(default_tokenizer)
|
||||
if state.verbose > 2:
|
||||
logger.debug(
|
||||
@@ -846,9 +852,9 @@ def truncate_messages(
|
||||
total_tokens, _ = count_total_tokens(messages, encoder, system_message)
|
||||
if total_tokens > max_prompt_size:
|
||||
# At this point, a single message with a single content part of type dict should remain
|
||||
assert (
|
||||
len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
|
||||
), "Expected a single message with a single content part remaining at this point in truncation"
|
||||
assert len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict), (
|
||||
"Expected a single message with a single content part remaining at this point in truncation"
|
||||
)
|
||||
|
||||
# Collate message content into single string to ease truncation
|
||||
part = messages[0].content[0]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
import requests
|
||||
import tqdm
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
|
||||
@@ -108,12 +108,12 @@ async def text_to_image(
|
||||
if "content_policy_violation" in e.message:
|
||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||
status_code = e.status_code # type: ignore
|
||||
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
|
||||
message = "Image generation blocked by OpenAI due to policy violation" # type: ignore
|
||||
yield image_url or image, status_code, message
|
||||
return
|
||||
else:
|
||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||
message = f"Image generation failed using OpenAI" # type: ignore
|
||||
message = "Image generation failed using OpenAI" # type: ignore
|
||||
status_code = e.status_code # type: ignore
|
||||
yield image_url or image, status_code, message
|
||||
return
|
||||
@@ -199,7 +199,7 @@ def generate_image_with_stability(
|
||||
|
||||
# Call Stability AI API to generate image
|
||||
response = requests.post(
|
||||
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||
"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
||||
files={"none": ""},
|
||||
data={
|
||||
|
||||
@@ -11,7 +11,7 @@ from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
construct_chat_history_for_operator,
|
||||
)
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_actions import RequestUserAction
|
||||
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_base import OperatorAgent
|
||||
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
|
||||
@@ -59,7 +59,7 @@ async def operate_environment(
|
||||
if not reasoning_model or not reasoning_model.vision_enabled:
|
||||
reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if not reasoning_model:
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
raise ValueError("No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
# Create conversation history from conversation log
|
||||
chat_history = construct_chat_history_for_operator(conversation_log)
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
import json
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult
|
||||
from khoj.processor.operator.operator_actions import (
|
||||
BackAction,
|
||||
ClickAction,
|
||||
DoubleClickAction,
|
||||
DragAction,
|
||||
GotoAction,
|
||||
KeypressAction,
|
||||
OperatorAction,
|
||||
Point,
|
||||
ScreenshotAction,
|
||||
ScrollAction,
|
||||
TypeAction,
|
||||
WaitAction,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
|
||||
@@ -18,7 +18,22 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from PIL import Image
|
||||
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_actions import (
|
||||
BackAction,
|
||||
ClickAction,
|
||||
DoubleClickAction,
|
||||
DragAction,
|
||||
GotoAction,
|
||||
KeyDownAction,
|
||||
KeypressAction,
|
||||
KeyUpAction,
|
||||
MoveAction,
|
||||
OperatorAction,
|
||||
RequestUserAction,
|
||||
ScrollAction,
|
||||
TypeAction,
|
||||
WaitAction,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
@@ -122,11 +137,10 @@ class GroundingAgentUitars:
|
||||
)
|
||||
|
||||
temperature = self.temperature
|
||||
top_k = self.top_k
|
||||
try_times = 3
|
||||
while not parsed_responses:
|
||||
if try_times <= 0:
|
||||
logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
logger.warning("Reach max retry times to fetch response from client, as error flag.")
|
||||
return "client error\nFAIL", []
|
||||
try:
|
||||
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
|
||||
@@ -163,7 +177,6 @@ class GroundingAgentUitars:
|
||||
prediction = None
|
||||
try_times -= 1
|
||||
temperature = 1
|
||||
top_k = -1
|
||||
|
||||
if prediction is None:
|
||||
return "client error\nFAIL", []
|
||||
@@ -264,9 +277,9 @@ class GroundingAgentUitars:
|
||||
raise ValueError(f"Unsupported environment type: {environment_type}")
|
||||
|
||||
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
|
||||
self.thoughts
|
||||
), "The number of observations and actions should be the same."
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), (
|
||||
"The number of observations and actions should be the same."
|
||||
)
|
||||
|
||||
self.history_images.append(base64.b64decode(current_state.screenshot))
|
||||
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
|
||||
@@ -524,7 +537,7 @@ class GroundingAgentUitars:
|
||||
parsed_actions = [self.parse_action_string(action.replace("\n", "\\n").lstrip()) for action in all_action]
|
||||
actions: list[dict] = []
|
||||
for action_instance, raw_str in zip(parsed_actions, all_action):
|
||||
if action_instance == None:
|
||||
if action_instance is None:
|
||||
print(f"Action can't parse: {raw_str}")
|
||||
raise ValueError(f"Action can't parse: {raw_str}")
|
||||
action_type = action_instance["function"]
|
||||
@@ -756,7 +769,7 @@ class GroundingAgentUitars:
|
||||
The pyautogui code string
|
||||
"""
|
||||
|
||||
pyautogui_code = f"import pyautogui\nimport time\n"
|
||||
pyautogui_code = "import pyautogui\nimport time\n"
|
||||
actions = []
|
||||
if isinstance(responses, dict):
|
||||
responses = [responses]
|
||||
@@ -774,7 +787,7 @@ class GroundingAgentUitars:
|
||||
if response_id == 0:
|
||||
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
|
||||
else:
|
||||
pyautogui_code += f"\ntime.sleep(1)\n"
|
||||
pyautogui_code += "\ntime.sleep(1)\n"
|
||||
|
||||
action_dict = response
|
||||
action_type = action_dict.get("action_type")
|
||||
@@ -846,17 +859,17 @@ class GroundingAgentUitars:
|
||||
if content:
|
||||
if input_swap:
|
||||
actions += TypeAction()
|
||||
pyautogui_code += f"\nimport pyperclip"
|
||||
pyautogui_code += "\nimport pyperclip"
|
||||
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
|
||||
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
|
||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||
pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
|
||||
pyautogui_code += "\ntime.sleep(0.5)\n"
|
||||
if content.endswith("\n") or content.endswith("\\n"):
|
||||
pyautogui_code += f"\npyautogui.press('enter')"
|
||||
pyautogui_code += "\npyautogui.press('enter')"
|
||||
else:
|
||||
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
|
||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||
pyautogui_code += "\ntime.sleep(0.5)\n"
|
||||
if content.endswith("\n") or content.endswith("\\n"):
|
||||
pyautogui_code += f"\npyautogui.press('enter')"
|
||||
pyautogui_code += "\npyautogui.press('enter')"
|
||||
|
||||
elif action_type in ["drag", "select"]:
|
||||
# Parsing drag or select action based on start and end_boxes
|
||||
@@ -869,9 +882,7 @@ class GroundingAgentUitars:
|
||||
x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2]
|
||||
ex = round(float((x1 + x2) / 2) * image_width, 3)
|
||||
ey = round(float((y1 + y2) / 2) * image_height, 3)
|
||||
pyautogui_code += (
|
||||
f"\npyautogui.moveTo({sx}, {sy})\n" f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
|
||||
)
|
||||
pyautogui_code += f"\npyautogui.moveTo({sx}, {sy})\n\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
|
||||
|
||||
elif action_type == "scroll":
|
||||
# Parsing scroll action
|
||||
@@ -888,11 +899,11 @@ class GroundingAgentUitars:
|
||||
y = None
|
||||
direction = action_inputs.get("direction", "")
|
||||
|
||||
if x == None:
|
||||
if x is None:
|
||||
if "up" in direction.lower():
|
||||
pyautogui_code += f"\npyautogui.scroll(5)"
|
||||
pyautogui_code += "\npyautogui.scroll(5)"
|
||||
elif "down" in direction.lower():
|
||||
pyautogui_code += f"\npyautogui.scroll(-5)"
|
||||
pyautogui_code += "\npyautogui.scroll(-5)"
|
||||
else:
|
||||
if "up" in direction.lower():
|
||||
pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
|
||||
@@ -923,7 +934,7 @@ class GroundingAgentUitars:
|
||||
pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
|
||||
|
||||
elif action_type in ["finished"]:
|
||||
pyautogui_code = f"DONE"
|
||||
pyautogui_code = "DONE"
|
||||
|
||||
else:
|
||||
pyautogui_code += f"\n# Unrecognized action type: {action_type}"
|
||||
|
||||
@@ -11,7 +11,32 @@ from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlo
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
||||
from khoj.processor.conversation.utils import AgentMessage
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_actions import (
|
||||
BackAction,
|
||||
ClickAction,
|
||||
CursorPositionAction,
|
||||
DoubleClickAction,
|
||||
DragAction,
|
||||
GotoAction,
|
||||
HoldKeyAction,
|
||||
KeypressAction,
|
||||
MouseDownAction,
|
||||
MouseUpAction,
|
||||
MoveAction,
|
||||
NoopAction,
|
||||
OperatorAction,
|
||||
Point,
|
||||
ScreenshotAction,
|
||||
ScrollAction,
|
||||
TerminalAction,
|
||||
TextEditorCreateAction,
|
||||
TextEditorInsertAction,
|
||||
TextEditorStrReplaceAction,
|
||||
TextEditorViewAction,
|
||||
TripleClickAction,
|
||||
TypeAction,
|
||||
WaitAction,
|
||||
)
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
@@ -518,7 +543,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
def model_default_headers(self) -> list[str]:
|
||||
"""Get the default computer use headers for the given model."""
|
||||
if self.vision_model.name.startswith("claude-3-7-sonnet"):
|
||||
return [f"computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
|
||||
return ["computer-use-2025-01-24", "token-efficient-tools-2025-02-19"]
|
||||
elif self.vision_model.name.startswith("claude-sonnet-4") or self.vision_model.name.startswith("claude-opus-4"):
|
||||
return ["computer-use-2025-01-24"]
|
||||
else:
|
||||
@@ -538,7 +563,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
@@ -563,7 +588,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<CONTEXT>
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
</CONTEXT>
|
||||
"""
|
||||
).lstrip()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from khoj.processor.conversation.utils import (
|
||||
)
|
||||
from khoj.processor.operator.grounding_agent import GroundingAgent
|
||||
from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_actions import OperatorAction, WaitAction
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
@@ -181,7 +181,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
elif action.type == "key_down":
|
||||
rendered_parts += [f'**Action**: Press Key "{action.key}"']
|
||||
elif action.type == "screenshot" and not current_state.screenshot:
|
||||
rendered_parts += [f"**Error**: Failed to take screenshot"]
|
||||
rendered_parts += ["**Error**: Failed to take screenshot"]
|
||||
elif action.type == "goto":
|
||||
rendered_parts += [f"**Action**: Open URL {action.url}"]
|
||||
else:
|
||||
@@ -317,7 +317,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
|
||||
* You are given the user's query and screenshots of the browser's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
* The current URL is {env_state.url}.
|
||||
|
||||
# Your Task
|
||||
@@ -362,7 +362,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
|
||||
* You are given the user's query and screenshots of the computer's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
|
||||
# Your Task
|
||||
* First look at the screenshots carefully to notice all pertinent information.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
@@ -10,7 +9,23 @@ from openai.types.responses import Response, ResponseOutputItem
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import AgentMessage
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_actions import (
|
||||
BackAction,
|
||||
ClickAction,
|
||||
DoubleClickAction,
|
||||
DragAction,
|
||||
GotoAction,
|
||||
KeypressAction,
|
||||
MoveAction,
|
||||
NoopAction,
|
||||
OperatorAction,
|
||||
Point,
|
||||
RequestUserAction,
|
||||
ScreenshotAction,
|
||||
ScrollAction,
|
||||
TypeAction,
|
||||
WaitAction,
|
||||
)
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult, OperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
@@ -152,7 +167,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
# Add screenshot data in openai message format
|
||||
action_result["output"] = {
|
||||
"type": "input_image",
|
||||
"image_url": f'data:image/webp;base64,{result_content["image"]}',
|
||||
"image_url": f"data:image/webp;base64,{result_content['image']}",
|
||||
"current_url": result_content["url"],
|
||||
}
|
||||
elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1:
|
||||
@@ -311,7 +326,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
elif block.type == "function_call":
|
||||
if block.name == "goto":
|
||||
args = json.loads(block.arguments)
|
||||
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']
|
||||
render_texts = [f"Open URL: {args.get('url', '[Missing URL]')}"]
|
||||
else:
|
||||
render_texts += [block.name]
|
||||
elif block.type == "computer_call":
|
||||
@@ -351,7 +366,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
@@ -374,7 +389,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<CONTEXT>
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current date is {datetime.today().strftime("%A, %B %-d, %Y")}.
|
||||
</CONTEXT>
|
||||
"""
|
||||
).lstrip()
|
||||
|
||||
@@ -247,7 +247,7 @@ class BrowserEnvironment(Environment):
|
||||
|
||||
case "drag":
|
||||
if not isinstance(action, DragAction):
|
||||
raise TypeError(f"Invalid action type for drag")
|
||||
raise TypeError("Invalid action type for drag")
|
||||
path = action.path
|
||||
if not path:
|
||||
error = "Missing path for drag action"
|
||||
|
||||
@@ -532,7 +532,7 @@ class ComputerEnvironment(Environment):
|
||||
else:
|
||||
return {"success": False, "output": process.stdout, "error": process.stderr}
|
||||
except asyncio.TimeoutError:
|
||||
return {"success": False, "output": "", "error": f"Command timed out after 120 seconds."}
|
||||
return {"success": False, "output": "", "error": "Command timed out after 120 seconds."}
|
||||
except Exception as e:
|
||||
return {"success": False, "output": "", "error": str(e)}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json # Used for working with JSON data
|
||||
import os
|
||||
|
||||
import requests # Used for making HTTP requests
|
||||
|
||||
@@ -385,7 +385,7 @@ async def read_webpages(
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
logger.info("Inferring web pages to read")
|
||||
urls = await infer_webpage_urls(
|
||||
query,
|
||||
max_webpages_to_read,
|
||||
|
||||
@@ -93,7 +93,7 @@ async def run_code(
|
||||
|
||||
# Run Code
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Running code snippet**"):
|
||||
async for event in send_status_func("**Running code snippet**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
try:
|
||||
with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import List, Optional, Union
|
||||
|
||||
import openai
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
@@ -94,7 +93,7 @@ def update(
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
else:
|
||||
logger.info(f"📪 Server indexed content updated via API")
|
||||
logger.info("📪 Server indexed content updated via API")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
||||
@@ -6,12 +6,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||
from khoj.database.models import Agent, Conversation, KhojUser, PriceTier
|
||||
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
|
||||
from khoj.utils.helpers import (
|
||||
|
||||
@@ -109,7 +109,7 @@ def post_automation(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
|
||||
return Response(
|
||||
content=f"Unable to create automation. Ensure the automation doesn't already exist.",
|
||||
content="Unable to create automation. Ensure the automation doesn't already exist.",
|
||||
media_type="text/plain",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -32,10 +31,10 @@ from khoj.database.adapters import (
|
||||
PublicConversationAdapters,
|
||||
aget_user_name,
|
||||
)
|
||||
from khoj.database.models import Agent, ChatMessageModel, KhojUser
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import is_local_api
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.prompts import no_entries_found
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
@@ -65,11 +64,8 @@ from khoj.routers.helpers import (
|
||||
acreate_title_from_history,
|
||||
agenerate_chat_response,
|
||||
aget_data_sources_and_output_format,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
gather_raw_query_files,
|
||||
generate_mermaidjs_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
get_message_from_queue,
|
||||
is_query_empty,
|
||||
@@ -89,13 +85,11 @@ from khoj.utils.helpers import (
|
||||
convert_image_to_webp,
|
||||
get_country_code_from_timezone,
|
||||
get_country_name_from_timezone,
|
||||
get_device,
|
||||
is_env_var_true,
|
||||
is_none_or_empty,
|
||||
is_operator_enabled,
|
||||
)
|
||||
from khoj.utils.rawconfig import (
|
||||
ChatRequestBody,
|
||||
FileAttachment,
|
||||
FileFilterRequest,
|
||||
FilesFilterRequest,
|
||||
@@ -689,7 +683,6 @@ async def event_generator(
|
||||
region = body.region
|
||||
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||
timezone = body.timezone
|
||||
raw_images = body.images
|
||||
raw_query_files = body.files
|
||||
|
||||
@@ -853,7 +846,8 @@ async def event_generator(
|
||||
if (
|
||||
len(train_of_thought) > 0
|
||||
and train_of_thought[-1]["type"] == ChatEvent.THOUGHT.value
|
||||
and type(train_of_thought[-1]["data"]) == type(data) == str
|
||||
and isinstance(train_of_thought[-1]["data"], str)
|
||||
and isinstance(data, str)
|
||||
):
|
||||
train_of_thought[-1]["data"] += data
|
||||
else:
|
||||
@@ -1075,11 +1069,11 @@ async def event_generator(
|
||||
|
||||
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||
if state.verbose > 1:
|
||||
logger.debug(f'Researched Results: {"".join(r.summarizedResult or "" for r in research_results)}')
|
||||
logger.debug(f"Researched Results: {''.join(r.summarizedResult or '' for r in research_results)}")
|
||||
|
||||
# Gather Context
|
||||
## Extract Document References
|
||||
if not ConversationCommand.Research in conversation_commands:
|
||||
if ConversationCommand.Research not in conversation_commands:
|
||||
try:
|
||||
async for result in search_documents(
|
||||
q,
|
||||
@@ -1218,7 +1212,7 @@ async def event_generator(
|
||||
else:
|
||||
code_results = result
|
||||
except ValueError as e:
|
||||
program_execution_context.append(f"Failed to run code")
|
||||
program_execution_context.append("Failed to run code")
|
||||
logger.warning(
|
||||
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
||||
exc_info=True,
|
||||
@@ -1297,7 +1291,7 @@ async def event_generator(
|
||||
inferred_queries.append(improved_image_prompt)
|
||||
if generated_image is None or status_code != 200:
|
||||
program_execution_context.append(f"Failed to generate image with {improved_image_prompt}")
|
||||
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
|
||||
async for result in send_event(ChatEvent.STATUS, "Failed to generate image"):
|
||||
yield result
|
||||
else:
|
||||
generated_images.append(generated_image)
|
||||
@@ -1315,7 +1309,7 @@ async def event_generator(
|
||||
yield result
|
||||
|
||||
if ConversationCommand.Diagram in conversation_commands:
|
||||
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
||||
async for result in send_event(ChatEvent.STATUS, "Creating diagram"):
|
||||
yield result
|
||||
|
||||
inferred_queries = []
|
||||
@@ -1372,7 +1366,7 @@ async def event_generator(
|
||||
return
|
||||
|
||||
## Generate Text Output
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||
async for result in send_event(ChatEvent.STATUS, "**Generating a well-informed response**"):
|
||||
yield result
|
||||
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ async def subscribe(request: Request):
|
||||
)
|
||||
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
|
||||
|
||||
logger.info(f'Stripe subscription {event["type"]} for {customer_email}')
|
||||
logger.info(f"Stripe subscription {event['type']} for {customer_email}")
|
||||
return {"success": success}
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ async def send_magic_link_email(email, unique_id, host):
|
||||
{
|
||||
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
|
||||
"to": email,
|
||||
"subject": f"Your login code to Khoj",
|
||||
"subject": "Your login code to Khoj",
|
||||
"html": html_content,
|
||||
}
|
||||
)
|
||||
@@ -98,11 +98,11 @@ async def send_query_feedback(uquery, kquery, sentiment, user_email):
|
||||
user_email=user_email if not is_none_or_empty(user_email) else "N/A",
|
||||
)
|
||||
# send feedback to fixed account
|
||||
r = resend.Emails.send(
|
||||
resend.Emails.send(
|
||||
{
|
||||
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
|
||||
"to": "team@khoj.dev",
|
||||
"subject": f"User Feedback",
|
||||
"subject": "User Feedback",
|
||||
"html": html_content,
|
||||
}
|
||||
)
|
||||
@@ -127,7 +127,7 @@ def send_task_email(name, email, query, result, subject, is_image=False):
|
||||
|
||||
r = resend.Emails.send(
|
||||
{
|
||||
"sender": f'Khoj <{os.environ.get("RESEND_EMAIL", "khoj@khoj.dev")}>',
|
||||
"sender": f"Khoj <{os.environ.get('RESEND_EMAIL', 'khoj@khoj.dev')}>",
|
||||
"to": email,
|
||||
"subject": f"✨ {subject}",
|
||||
"html": html_content,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import json
|
||||
@@ -47,14 +46,12 @@ from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
aget_user_by_email,
|
||||
ais_user_subscribed,
|
||||
create_khoj_token,
|
||||
get_default_search_model,
|
||||
get_khoj_tokens,
|
||||
get_user_name,
|
||||
get_user_notion_config,
|
||||
get_user_subscription_state,
|
||||
is_user_subscribed,
|
||||
run_with_process_lock,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
@@ -160,7 +157,7 @@ def validate_chat_model(user: KhojUser):
|
||||
|
||||
async def is_ready_to_chat(user: KhojUser):
|
||||
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
||||
if user_chat_model == None:
|
||||
if user_chat_model is None:
|
||||
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
if (
|
||||
@@ -581,7 +578,7 @@ async def generate_online_subqueries(
|
||||
)
|
||||
return {q}
|
||||
return response
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
|
||||
return {q}
|
||||
|
||||
@@ -1172,8 +1169,8 @@ async def search_documents(
|
||||
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
||||
|
||||
if (
|
||||
not ConversationCommand.Notes in conversation_commands
|
||||
and not ConversationCommand.Default in conversation_commands
|
||||
ConversationCommand.Notes not in conversation_commands
|
||||
and ConversationCommand.Default not in conversation_commands
|
||||
and not agent_has_entries
|
||||
):
|
||||
yield compiled_references, inferred_queries, q
|
||||
@@ -1325,8 +1322,8 @@ async def extract_questions(
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [query]
|
||||
return queries
|
||||
except:
|
||||
logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.")
|
||||
except Exception:
|
||||
logger.warning("LLM returned invalid JSON. Falling back to using user message as search query.")
|
||||
return [query]
|
||||
|
||||
|
||||
@@ -1351,7 +1348,7 @@ async def execute_search(
|
||||
return results
|
||||
|
||||
if q is None or q == "":
|
||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||
logger.warning("No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
|
||||
# initialize variables
|
||||
@@ -1364,7 +1361,7 @@ async def execute_search(
|
||||
if user:
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||
if query_cache_key in state.query_cache[user.uuid]:
|
||||
logger.debug(f"Return response from query cache")
|
||||
logger.debug("Return response from query cache")
|
||||
return state.query_cache[user.uuid][query_cache_key]
|
||||
|
||||
# Encode query with filter terms removed
|
||||
@@ -1875,8 +1872,8 @@ class ApiUserRateLimiter:
|
||||
|
||||
user: KhojUser = websocket.scope["user"].object
|
||||
subscribed = has_required_scope(websocket, ["premium"])
|
||||
current_window = "today" if self.window == 60 * 60 * 24 else f"now"
|
||||
next_window = "tomorrow" if self.window == 60 * 60 * 24 else f"in a bit"
|
||||
current_window = "today" if self.window == 60 * 60 * 24 else "now"
|
||||
next_window = "tomorrow" if self.window == 60 * 60 * 24 else "in a bit"
|
||||
common_message_prefix = f"I'm glad you're enjoying interacting with me! You've unfortunately exceeded your usage limit for {current_window}."
|
||||
|
||||
# Remove requests outside of the time window
|
||||
@@ -2219,7 +2216,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str, us
|
||||
should_notify_result = response["decision"] == "Yes"
|
||||
reason = response.get("reason", "unknown")
|
||||
logger.info(
|
||||
f'Decided to {"not " if not should_notify_result else ""}notify user of automation response because of reason: {reason}.'
|
||||
f"Decided to {'not ' if not should_notify_result else ''}notify user of automation response because of reason: {reason}."
|
||||
)
|
||||
return should_notify_result
|
||||
except Exception as e:
|
||||
@@ -2313,7 +2310,7 @@ def scheduled_chat(
|
||||
response_map = raw_response.json()
|
||||
ai_response = response_map.get("response") or response_map.get("image")
|
||||
is_image = False
|
||||
if type(ai_response) == dict:
|
||||
if isinstance(ai_response, dict):
|
||||
is_image = ai_response.get("image") is not None
|
||||
else:
|
||||
ai_response = raw_response.text
|
||||
@@ -2460,12 +2457,12 @@ async def aschedule_automation(
|
||||
|
||||
def construct_automation_created_message(automation: Job, crontime: str, query_to_run: str, subject: str):
|
||||
# Display next run time in user timezone instead of UTC
|
||||
schedule = f'{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime("%Z")}'
|
||||
schedule = f"{cron_descriptor.get_description(crontime)} {automation.next_run_time.strftime('%Z')}"
|
||||
next_run_time = automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z")
|
||||
# Remove /automated_task prefix from inferred_query
|
||||
unprefixed_query_to_run = re.sub(r"^\/automated_task\s*", "", query_to_run)
|
||||
# Create the automation response
|
||||
automation_icon_url = f"/static/assets/icons/automation.svg"
|
||||
automation_icon_url = "/static/assets/icons/automation.svg"
|
||||
return f"""
|
||||
###  Created Automation
|
||||
- Subject: **{subject}**
|
||||
@@ -2713,13 +2710,13 @@ def configure_content(
|
||||
t: Optional[state.SearchType] = state.SearchType.All,
|
||||
) -> bool:
|
||||
success = True
|
||||
if t == None:
|
||||
if t is None:
|
||||
t = state.SearchType.All
|
||||
|
||||
if t is not None and t in [type.value for type in state.SearchType]:
|
||||
t = state.SearchType(t)
|
||||
|
||||
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
||||
if t is not None and t.value not in [type.value for type in state.SearchType]:
|
||||
logger.warning(f"🚨 Invalid search type: {t}")
|
||||
return False
|
||||
|
||||
@@ -2988,7 +2985,7 @@ async def grep_files(
|
||||
query += f" {' and '.join(context_info)}"
|
||||
if line_count > max_results:
|
||||
if lines_before or lines_after:
|
||||
query += f" for"
|
||||
query += " for"
|
||||
query += f" first {max_results} results"
|
||||
return query
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from khoj.processor.conversation.utils import (
|
||||
ResearchIteration,
|
||||
ToolCall,
|
||||
construct_iteration_history,
|
||||
construct_structured_message,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
@@ -24,7 +23,6 @@ from khoj.processor.tools.online_search import read_webpages_content, search_onl
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
generate_summary_from_files,
|
||||
get_message_from_queue,
|
||||
grep_files,
|
||||
list_files,
|
||||
@@ -184,7 +182,7 @@ async def apick_next_tool(
|
||||
# TODO: Handle multiple tool calls.
|
||||
response_text = response.text
|
||||
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Otherwise assume the model has decided to end the research run and respond to the user.
|
||||
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)
|
||||
|
||||
@@ -199,7 +197,7 @@ async def apick_next_tool(
|
||||
if i.warning is None and isinstance(i.query, ToolCall)
|
||||
}
|
||||
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
|
||||
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
|
||||
warning = "Repeated tool, query combination detected. Skipping iteration. Try something different."
|
||||
# Only send client status updates if we'll execute this iteration and model has thoughts to share.
|
||||
elif send_status_func and not is_none_or_empty(response.thought):
|
||||
async for event in send_status_func(response.thought):
|
||||
|
||||
@@ -4,12 +4,10 @@ from typing import List
|
||||
|
||||
class BaseFilter(ABC):
|
||||
@abstractmethod
|
||||
def get_filter_terms(self, query: str) -> List[str]:
|
||||
...
|
||||
def get_filter_terms(self, query: str) -> List[str]: ...
|
||||
|
||||
def can_filter(self, raw_query: str) -> bool:
|
||||
return len(self.get_filter_terms(raw_query)) > 0
|
||||
|
||||
@abstractmethod
|
||||
def defilter(self, query: str) -> str:
|
||||
...
|
||||
def defilter(self, query: str) -> str: ...
|
||||
|
||||
@@ -9,9 +9,8 @@ from asgiref.sync import sync_to_async
|
||||
from sentence_transformers import util
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||
from khoj.database.models import Agent
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, timer
|
||||
|
||||
@@ -77,7 +77,7 @@ class AsyncIteratorWrapper:
|
||||
|
||||
|
||||
def is_none_or_empty(item):
|
||||
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
||||
return item is None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
||||
|
||||
|
||||
def to_snake_case_from_dash(item: str):
|
||||
@@ -97,7 +97,7 @@ def get_from_dict(dictionary, *args):
|
||||
Returns: dictionary[args[0]][args[1]]... or None if any keys missing"""
|
||||
current = dictionary
|
||||
for arg in args:
|
||||
if not hasattr(current, "__iter__") or not arg in current:
|
||||
if not hasattr(current, "__iter__") or arg not in current:
|
||||
return None
|
||||
current = current[arg]
|
||||
return current
|
||||
@@ -751,7 +751,7 @@ def is_valid_url(url: str) -> bool:
|
||||
try:
|
||||
result = urlparse(url.strip())
|
||||
return all([result.scheme, result.netloc])
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@@ -759,7 +759,7 @@ def is_internet_connected():
|
||||
try:
|
||||
response = requests.head("https://www.google.com")
|
||||
return response.status_code == 200
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -60,9 +60,7 @@ def initialization(interactive: bool = True):
|
||||
]
|
||||
default_chat_models = known_available_models + other_available_models
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}"
|
||||
)
|
||||
logger.warning(f"⚠️ Failed to fetch {provider} chat models. Fallback to default models. Error: {str(e)}")
|
||||
|
||||
# Set up OpenAI's online chat models
|
||||
openai_configured, openai_provider = _setup_chat_model_provider(
|
||||
|
||||
@@ -8,12 +8,10 @@ from tqdm import trange
|
||||
|
||||
class BaseEncoder(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, model_name: str, device: torch.device = None, **kwargs):
|
||||
...
|
||||
def __init__(self, model_name: str, device: torch.device = None, **kwargs): ...
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor:
|
||||
...
|
||||
def encode(self, entries: List[str], device: torch.device = None, **kwargs) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class OpenAI(BaseEncoder):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# System Packages
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, List
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -30,7 +30,7 @@ def v1_telemetry(telemetry_data: List[Dict[str, str]]):
|
||||
try:
|
||||
for row in telemetry_data:
|
||||
posthog.capture(row["server_id"], "api_request", row)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Could not POST equest to new khoj telemetry server. Contact developer to get this fixed.",
|
||||
|
||||
@@ -326,7 +326,7 @@ File statistics:
|
||||
- Code examples: Yes
|
||||
- Purpose: Stress testing atomic agent updates
|
||||
|
||||
{'Additional padding content. ' * 20}
|
||||
{"Additional padding content. " * 20}
|
||||
|
||||
End of file {i}.
|
||||
"""
|
||||
|
||||
@@ -462,7 +462,7 @@ def evaluate_response_with_gemini(
|
||||
Ground Truth: {ground_truth}
|
||||
|
||||
Provide your evaluation in the following json format:
|
||||
{"explanation:" "[How you made the decision?)", "decision:" "(TRUE if response contains key information, FALSE otherwise)"}
|
||||
{"explanation:[How you made the decision?)", "decision:(TRUE if response contains key information, FALSE otherwise)"}
|
||||
"""
|
||||
gemini_api_url = (
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/{eval_model}:generateContent?key={GEMINI_API_KEY}"
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_create_default_agent(default_user: KhojUser):
|
||||
assert agent.input_tools == []
|
||||
assert agent.output_modes == []
|
||||
assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC
|
||||
assert agent.managed_by_admin == True
|
||||
assert agent.managed_by_admin
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -178,7 +178,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
await AgentAdapters.aupdate_agent(
|
||||
default_user2,
|
||||
"Test Agent",
|
||||
"Test Personality",
|
||||
@@ -290,17 +290,17 @@ async def test_large_knowledge_base_atomic_update(
|
||||
assert len(final_entries) > initial_entries_count, "Should have more entries after update"
|
||||
|
||||
# With 180 files, we should have many entries (each file creates multiple entries)
|
||||
assert (
|
||||
len(final_entries) >= expected_file_count
|
||||
), f"Expected at least {expected_file_count} entries, got {len(final_entries)}"
|
||||
assert len(final_entries) >= expected_file_count, (
|
||||
f"Expected at least {expected_file_count} entries, got {len(final_entries)}"
|
||||
)
|
||||
|
||||
# Verify no partial state - all entries should correspond to the final file set
|
||||
entry_file_paths = {entry.file_path for entry in final_entries}
|
||||
|
||||
# All file objects should have corresponding entries
|
||||
assert file_paths_in_db.issubset(
|
||||
entry_file_paths
|
||||
), "All file objects should have corresponding entries - atomic update verification"
|
||||
assert file_paths_in_db.issubset(entry_file_paths), (
|
||||
"All file objects should have corresponding entries - atomic update verification"
|
||||
)
|
||||
|
||||
# Additional stress test: verify referential integrity
|
||||
# Count entries per file to ensure no partial file processing
|
||||
@@ -333,7 +333,7 @@ async def test_concurrent_agent_updates_atomicity(
|
||||
test_files = available_files # Use all available files for the stress test
|
||||
|
||||
# Create initial agent
|
||||
agent = await AgentAdapters.aupdate_agent(
|
||||
await AgentAdapters.aupdate_agent(
|
||||
default_user2,
|
||||
"Concurrent Test Agent",
|
||||
"Test concurrent updates",
|
||||
@@ -391,14 +391,14 @@ async def test_concurrent_agent_updates_atomicity(
|
||||
file_object_paths = {fo.file_name for fo in final_file_objects}
|
||||
|
||||
# All entries should have corresponding file objects
|
||||
assert entry_file_paths.issubset(
|
||||
file_object_paths
|
||||
), "All entries should have corresponding file objects - indicates atomic update worked"
|
||||
assert entry_file_paths.issubset(file_object_paths), (
|
||||
"All entries should have corresponding file objects - indicates atomic update worked"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# If we get database integrity errors, that's actually expected behavior
|
||||
# with proper atomic transactions - they should fail cleanly rather than
|
||||
# allowing partial updates
|
||||
assert (
|
||||
"database" in str(e).lower() or "integrity" in str(e).lower()
|
||||
), f"Expected database/integrity error with concurrent updates, got: {e}"
|
||||
assert "database" in str(e).lower() or "integrity" in str(e).lower(), (
|
||||
f"Expected database/integrity error with concurrent updates, got: {e}"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from urllib.parse import quote
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
@@ -101,7 +100,7 @@ def test_update_with_invalid_content_type(client):
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/update?t=invalid_content_type", headers=headers)
|
||||
response = client.get("/api/update?t=invalid_content_type", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
@@ -114,7 +113,7 @@ def test_regenerate_with_invalid_content_type(client):
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers)
|
||||
response = client.get("/api/update?force=true&t=invalid_content_type", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
@@ -238,13 +237,13 @@ def test_regenerate_with_valid_content_type(client):
|
||||
def test_regenerate_with_github_fails_without_pat(client):
|
||||
# Act
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
response = client.get(f"/api/update?force=true&t=github", headers=headers)
|
||||
response = client.get("/api/update?force=true&t=github", headers=headers)
|
||||
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
|
||||
# Act
|
||||
response = client.patch(f"/api/content?t=github", files=files, headers=headers)
|
||||
response = client.patch("/api/content?t=github", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
|
||||
@@ -270,7 +269,7 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/content/types", headers=headers)
|
||||
response = client.get("/api/content/types", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -286,7 +285,7 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||
client = TestClient(fastapi_app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/content/types")
|
||||
response = client.get("/api/content/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -454,8 +453,8 @@ def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojAp
|
||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||
|
||||
# Act
|
||||
auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query}, headers=headers)
|
||||
no_auth_response = chat_client_with_auth.post(f"/api/chat", json={"q": query})
|
||||
auth_response = chat_client_with_auth.post("/api/chat", json={"q": query}, headers=headers)
|
||||
no_auth_response = chat_client_with_auth.post("/api/chat", json={"q": query})
|
||||
|
||||
# Assert
|
||||
assert auth_response.status_code == 200
|
||||
|
||||
@@ -77,12 +77,12 @@ class TestTruncateMessage:
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert (
|
||||
len(chat_history) == 1
|
||||
), "Only most recent message should be present as it itself is larger than context size"
|
||||
assert len(truncated_chat_history[0].content) < len(
|
||||
copy_big_chat_message.content
|
||||
), "message content list should be modified"
|
||||
assert len(chat_history) == 1, (
|
||||
"Only most recent message should be present as it itself is larger than context size"
|
||||
)
|
||||
assert len(truncated_chat_history[0].content) < len(copy_big_chat_message.content), (
|
||||
"message content list should be modified"
|
||||
)
|
||||
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||
@@ -101,9 +101,9 @@ class TestTruncateMessage:
|
||||
|
||||
# Assert
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert (
|
||||
len(chat_history) == 1
|
||||
), "Only most recent message should be present as it itself is larger than context size"
|
||||
assert len(chat_history) == 1, (
|
||||
"Only most recent message should be present as it itself is larger than context size"
|
||||
)
|
||||
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
@@ -150,9 +150,9 @@ class TestTruncateMessage:
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
|
||||
assert (
|
||||
len(chat_messages) == 1
|
||||
), "Only most recent message should be present as it itself is larger than context size"
|
||||
assert len(chat_messages) == 1, (
|
||||
"Only most recent message should be present as it itself is larger than context size"
|
||||
)
|
||||
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||
|
||||
@@ -172,9 +172,9 @@ class TestTruncateMessage:
|
||||
# The original object has been modified. Verify certain properties
|
||||
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
|
||||
assert (
|
||||
len(chat_messages) == 1
|
||||
), "Only most recent message should be present as it itself is larger than context size"
|
||||
assert len(chat_messages) == 1, (
|
||||
"Only most recent message should be present as it itself is larger than context size"
|
||||
)
|
||||
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||
|
||||
|
||||
|
||||
@@ -162,15 +162,15 @@ def test_date_extraction():
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected d.m.Y structured date to be extracted"
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("CLOCK: [1984-04-01 Sun 09:50]--[1984-04-01 Sun 10:10] => 24:20")
|
||||
assert extracted_dates == [
|
||||
datetime(1984, 4, 1, 0, 0, 0)
|
||||
], "Expected single deduplicated date extracted from logbook entry"
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
|
||||
"Expected single deduplicated date extracted from logbook entry"
|
||||
)
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("CLOCK: [1984/03/31 mer 09:50]--[1984/04/01 mer 10:10] => 24:20")
|
||||
expected_dates = [datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 3, 31, 0, 0, 0)]
|
||||
assert all(
|
||||
[dt in extracted_dates for dt in expected_dates]
|
||||
), "Expected multiple different dates extracted from logbook entry"
|
||||
assert all([dt in extracted_dates for dt in expected_dates]), (
|
||||
"Expected multiple different dates extracted from logbook entry"
|
||||
)
|
||||
|
||||
|
||||
def test_natual_date_extraction():
|
||||
@@ -187,9 +187,9 @@ def test_natual_date_extraction():
|
||||
assert datetime(1984, 4, 4, 0, 0, 0) in extracted_dates, "Expected natural date to be extracted"
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("head 11th april 1984 tail")
|
||||
assert (
|
||||
datetime(1984, 4, 11, 0, 0, 0) in extracted_dates
|
||||
), "Expected natural date with lowercase month to be extracted"
|
||||
assert datetime(1984, 4, 11, 0, 0, 0) in extracted_dates, (
|
||||
"Expected natural date with lowercase month to be extracted"
|
||||
)
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("head 23rd april 84 tail")
|
||||
assert datetime(1984, 4, 23, 0, 0, 0) in extracted_dates, "Expected natural date with 2-digit year to be extracted"
|
||||
@@ -201,16 +201,16 @@ def test_natual_date_extraction():
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], "Expected partial natural date to be extracted"
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("head Apr 1984 tail")
|
||||
assert extracted_dates == [
|
||||
datetime(1984, 4, 1, 0, 0, 0)
|
||||
], "Expected partial natural date with short month to be extracted"
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
|
||||
"Expected partial natural date with short month to be extracted"
|
||||
)
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("head apr 1984 tail")
|
||||
assert extracted_dates == [
|
||||
datetime(1984, 4, 1, 0, 0, 0)
|
||||
], "Expected partial natural date with lowercase month to be extracted"
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
|
||||
"Expected partial natural date with lowercase month to be extracted"
|
||||
)
|
||||
|
||||
extracted_dates = DateFilter().extract_dates("head apr 84 tail")
|
||||
assert extracted_dates == [
|
||||
datetime(1984, 4, 1, 0, 0, 0)
|
||||
], "Expected partial natural date with 2-digit year to be extracted"
|
||||
assert extracted_dates == [datetime(1984, 4, 1, 0, 0, 0)], (
|
||||
"Expected partial natural date with 2-digit year to be extracted"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from khoj.processor.content.images.image_to_entries import ImageToEntries
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntrie
|
||||
def test_extract_markdown_with_no_headings(tmp_path):
|
||||
"Convert markdown file with no heading to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
"""
|
||||
@@ -35,7 +35,7 @@ def test_extract_markdown_with_no_headings(tmp_path):
|
||||
def test_extract_single_markdown_entry(tmp_path):
|
||||
"Convert markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""### Heading
|
||||
entry = """### Heading
|
||||
\t\r
|
||||
Body Line 1
|
||||
"""
|
||||
@@ -55,7 +55,7 @@ def test_extract_single_markdown_entry(tmp_path):
|
||||
def test_extract_multiple_markdown_entries(tmp_path):
|
||||
"Convert multiple markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
### Heading 1
|
||||
\t\r
|
||||
Heading 1 Body Line 1
|
||||
@@ -81,7 +81,7 @@ def test_extract_multiple_markdown_entries(tmp_path):
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract markdown entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
# Heading 1
|
||||
## Sub-Heading 1.1
|
||||
# Heading 2
|
||||
@@ -104,7 +104,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
||||
"Extract markdown entries when deeper child level before shallower child level."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
# Heading 1
|
||||
#### Sub-Heading 1.1
|
||||
## Sub-Heading 1.2
|
||||
@@ -129,7 +129,7 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
||||
def test_extract_entries_with_text_before_headings(tmp_path):
|
||||
"Extract markdown entries with some text before any headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
Text before headings
|
||||
# Heading 1
|
||||
body line 1
|
||||
@@ -149,15 +149,15 @@ body line 2
|
||||
assert len(entries[1]) == 3
|
||||
assert entries[1][0].raw == "\nText before headings"
|
||||
assert entries[1][1].raw == "# Heading 1\nbody line 1"
|
||||
assert (
|
||||
entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n"
|
||||
), "Ensure raw entry includes heading ancestory"
|
||||
assert entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", (
|
||||
"Ensure raw entry includes heading ancestory"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse markdown file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
@@ -180,7 +180,7 @@ body line 1.1
|
||||
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse markdown entry with child headings as single entry if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
@@ -201,13 +201,13 @@ longer body line 2.1
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert len(entries[1]) == 3
|
||||
assert (
|
||||
entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
|
||||
), "First entry includes children headings"
|
||||
assert entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1", (
|
||||
"First entry includes children headings"
|
||||
)
|
||||
assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
|
||||
assert (
|
||||
entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
|
||||
), "Third entry is second entries child heading"
|
||||
assert entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n", (
|
||||
"Third entry is second entries child heading"
|
||||
)
|
||||
|
||||
|
||||
def test_line_number_tracking_in_recursive_split():
|
||||
@@ -252,14 +252,16 @@ def test_line_number_tracking_in_recursive_split():
|
||||
|
||||
assert entry.uri is not None, f"Entry '{entry}' has a None URI."
|
||||
assert match is not None, f"URI format is incorrect: {entry.uri}"
|
||||
assert (
|
||||
filepath_from_uri == markdown_file_path
|
||||
), f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'"
|
||||
assert filepath_from_uri == markdown_file_path, (
|
||||
f"File path in URI '{filepath_from_uri}' does not match expected '{markdown_file_path}'"
|
||||
)
|
||||
|
||||
# Ensure the first non-heading line in the compiled entry matches the line in the file
|
||||
assert (
|
||||
cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip()
|
||||
), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
|
||||
), (
|
||||
f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
|
||||
)
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
||||
@@ -343,12 +343,12 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
"compiled": """2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
"compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
@@ -389,12 +389,12 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
"compiled": """2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
"compiled": """2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
@@ -452,17 +452,17 @@ async def test_ask_for_clarification_if_not_enough_context_in_question():
|
||||
# Arrange
|
||||
context = [
|
||||
{
|
||||
"compiled": f"""# Ramya
|
||||
"compiled": """# Ramya
|
||||
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
|
||||
"file": "Family.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""# Fang
|
||||
"compiled": """# Fang
|
||||
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
|
||||
"file": "Family.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""# Aiyla
|
||||
"compiled": """# Aiyla
|
||||
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
|
||||
"file": "Family.md",
|
||||
},
|
||||
@@ -497,9 +497,9 @@ async def test_agent_prompt_should_be_used(openai_agent):
|
||||
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
|
||||
# Arrange
|
||||
context = [
|
||||
{"compiled": f"""I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"},
|
||||
{"compiled": f"""I went to the store and bought some apples for 1.30""", "file": "Ledger.md"},
|
||||
{"compiled": f"""I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"},
|
||||
{"compiled": """I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"},
|
||||
{"compiled": """I went to the store and bought some apples for 1.30""", "file": "Ledger.md"},
|
||||
{"compiled": """I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"},
|
||||
]
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
|
||||
@@ -539,13 +539,13 @@ async def test_websearch_with_operators(chat_client, default_user2):
|
||||
responses = await generate_online_subqueries(user_query, [], None, default_user2)
|
||||
|
||||
# Assert
|
||||
assert any(
|
||||
["reddit.com/r/worldnews" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
assert any(["reddit.com/r/worldnews" in response for response in responses]), (
|
||||
"Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
)
|
||||
|
||||
assert any(
|
||||
["site:reddit.com" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
assert any(["site:reddit.com" in response for response in responses]), (
|
||||
"Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -559,9 +559,9 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
|
||||
responses = await generate_online_subqueries(user_query, [], None, default_user2)
|
||||
|
||||
# Assert
|
||||
assert any(
|
||||
["site:khoj.dev" in response for response in responses]
|
||||
), "Expected search query to include site:khoj.dev but got: " + str(responses)
|
||||
assert any(["site:khoj.dev" in response for response in responses]), (
|
||||
"Expected search query to include site:khoj.dev but got: " + str(responses)
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -693,9 +693,9 @@ def test_infer_task_scheduling_request(
|
||||
for expected_q in expected_qs:
|
||||
assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}"
|
||||
for unexpected_q in unexpected_qs:
|
||||
assert (
|
||||
unexpected_q not in inferred_query
|
||||
), f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'"
|
||||
assert unexpected_q not in inferred_query, (
|
||||
f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -33,7 +33,7 @@ def create_conversation(message_list, user, agent=None):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"})
|
||||
response = chat_client.post("/api/chat", json={"q": "Hello, my name is Testatron. Who are you?"})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -50,7 +50,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
def test_chat_with_online_content(chat_client):
|
||||
# Act
|
||||
q = "/online give me the link to paul graham's essay how to do great work"
|
||||
response = chat_client.post(f"/api/chat?", json={"q": q})
|
||||
response = chat_client.post("/api/chat?", json={"q": q})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -59,9 +59,9 @@ def test_chat_with_online_content(chat_client):
|
||||
"paulgraham.com/hwh.html",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
), f"Expected links: {expected_responses}. Actual response: {response_message}"
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
f"Expected links: {expected_responses}. Actual response: {response_message}"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -70,15 +70,15 @@ def test_chat_with_online_content(chat_client):
|
||||
def test_chat_with_online_webpage_content(chat_client):
|
||||
# Act
|
||||
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
||||
response = chat_client.post(f"/api/chat", json={"q": q})
|
||||
response = chat_client.post("/api/chat", json={"q": q})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
expected_responses = ["185", "1871", "horse"]
|
||||
assert response.status_code == 200
|
||||
assert any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
), f"Expected links: {expected_responses}. Actual response: {response_message}"
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
f"Expected links: {expected_responses}. Actual response: {response_message}"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -93,7 +93,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "What is my name?"})
|
||||
response = chat_client.post("/api/chat", json={"q": "What is my name?"})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -120,7 +120,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "Where was Xi Li born?"})
|
||||
response = chat_client.post("/api/chat", json={"q": "Where was Xi Li born?"})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -144,7 +144,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client_no_background.post(f"/api/chat", json={"q": "Where was I born?"})
|
||||
response = chat_client_no_background.post("/api/chat", json={"q": "Where was I born?"})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -167,7 +167,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"})
|
||||
response = chat_client.post("/api/chat", json={"q": "Where was I born?"})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -192,7 +192,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "Where was I born?"})
|
||||
response = chat_client.post("/api/chat", json={"q": "Where was I born?"})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -222,7 +222,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser):
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -240,7 +240,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -258,7 +258,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client_no_background.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client_no_background.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -291,7 +291,7 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = "/summarize"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
@@ -322,7 +322,7 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = "/summarize tell me about Xiu"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
@@ -349,7 +349,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
|
||||
)
|
||||
|
||||
query = "/summarize"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -365,7 +365,7 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -400,11 +400,11 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation2.id)})
|
||||
response_message_conv2 = response.json()["response"]
|
||||
|
||||
# now make sure that the file filter is still in conversation 1
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation1.id)})
|
||||
response_message_conv1 = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -430,7 +430,7 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
|
||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||
)
|
||||
query = urllib.parse.quote("/summarize")
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
@@ -462,7 +462,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -477,7 +477,7 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
|
||||
def test_answer_requires_current_date_awareness(chat_client):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Act
|
||||
response = chat_client.post(f"/api/chat", json={"q": "Where did I have lunch today?", "stream": True})
|
||||
response = chat_client.post("/api/chat", json={"q": "Where did I have lunch today?", "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -496,7 +496,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
query = "How much did I spend on dining this year?"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -518,7 +518,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||
|
||||
# Act
|
||||
query = "Write a haiku about unit testing. Do not say anything else."
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -536,7 +536,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
||||
# Act
|
||||
query = "What is the name of Namitas older son?"
|
||||
response = chat_client_no_background.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client_no_background.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"].lower()
|
||||
|
||||
# Assert
|
||||
@@ -571,7 +571,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
|
||||
|
||||
# Act
|
||||
query = "What is my name?"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
@@ -604,9 +604,7 @@ def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: K
|
||||
|
||||
# Act
|
||||
query = "/general What is my favorite color?"
|
||||
response = chat_client.post(
|
||||
f"/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True}
|
||||
)
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -639,7 +637,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||
|
||||
# Act
|
||||
query = "/general What did I buy for breakfast?"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response = chat_client.post("/api/chat", json={"q": query, "conversation_id": str(conversation.id)})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert that agent only responds with the summary of spending
|
||||
@@ -657,7 +655,7 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
query = "Is Xi Li older than Namita? Just say the older persons full name"
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"].lower()
|
||||
|
||||
# Assert
|
||||
@@ -681,7 +679,7 @@ def test_answer_using_file_filter(chat_client):
|
||||
query = (
|
||||
'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
||||
)
|
||||
response = chat_client.post(f"/api/chat", json={"q": query})
|
||||
response = chat_client.post("/api/chat", json={"q": query})
|
||||
response_message = response.json()["response"].lower()
|
||||
|
||||
# Assert
|
||||
|
||||
@@ -12,7 +12,7 @@ def test_configure_indexing_heading_only_entries(tmp_path):
|
||||
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
||||
# Arrange
|
||||
entry = f"""*** Heading
|
||||
entry = """*** Heading
|
||||
:PROPERTIES:
|
||||
:ID: 42-42-42
|
||||
:END:
|
||||
@@ -74,7 +74,7 @@ def test_entry_split_when_exceeds_max_tokens():
|
||||
"Ensure entries with compiled words exceeding max_tokens are split."
|
||||
# Arrange
|
||||
tmp_path = "/tmp/test.org"
|
||||
entry = f"""*** Heading
|
||||
entry = """*** Heading
|
||||
\t\r
|
||||
Body Line
|
||||
"""
|
||||
@@ -99,7 +99,7 @@ def test_entry_split_when_exceeds_max_tokens():
|
||||
def test_entry_split_drops_large_words():
|
||||
"Ensure entries drops words larger than specified max word length from compiled version."
|
||||
# Arrange
|
||||
entry_text = f"""First Line
|
||||
entry_text = """First Line
|
||||
dog=1\n\r\t
|
||||
cat=10
|
||||
car=4
|
||||
@@ -124,7 +124,7 @@ book=2
|
||||
def test_parse_org_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse org file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
original_entry = f"""
|
||||
original_entry = """
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
@@ -133,7 +133,7 @@ body line 1.1
|
||||
data = {
|
||||
f"{tmp_path}": original_entry,
|
||||
}
|
||||
expected_entry = f"""
|
||||
expected_entry = """
|
||||
* Heading 1
|
||||
body line 1
|
||||
|
||||
@@ -155,7 +155,7 @@ body line 1.1
|
||||
def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse org entry with child headings as single entry only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
@@ -205,7 +205,7 @@ longer body line 2.1
|
||||
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
|
||||
"Parse org sibling entries as separate entries only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
@@ -267,7 +267,7 @@ body line 3.1
|
||||
def test_entry_with_body_to_entry(tmp_path):
|
||||
"Ensure entries with valid body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""*** Heading
|
||||
entry = """*** Heading
|
||||
:PROPERTIES:
|
||||
:ID: 42-42-42
|
||||
:END:
|
||||
@@ -290,7 +290,7 @@ def test_entry_with_body_to_entry(tmp_path):
|
||||
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
|
||||
"Ensure intro text before any headings is indexed."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
Intro text
|
||||
|
||||
* Entry Heading
|
||||
@@ -312,7 +312,7 @@ Intro text
|
||||
def test_file_with_no_headings_to_entry(tmp_path):
|
||||
"Ensure files with no heading, only body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
"""
|
||||
@@ -332,7 +332,7 @@ def test_file_with_no_headings_to_entry(tmp_path):
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract org entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
* Heading 1
|
||||
** Sub-Heading 1.1
|
||||
* Heading 2
|
||||
@@ -396,14 +396,16 @@ def test_line_number_tracking_in_recursive_split():
|
||||
|
||||
assert entry.uri is not None, f"Entry '{entry}' has a None URI."
|
||||
assert match is not None, f"URI format is incorrect: {entry.uri}"
|
||||
assert (
|
||||
filepath_from_uri == org_file_path
|
||||
), f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'"
|
||||
assert filepath_from_uri == org_file_path, (
|
||||
f"File path in URI '{filepath_from_uri}' does not match expected '{org_file_path}'"
|
||||
)
|
||||
|
||||
# Ensure the first non-heading line in the compiled entry matches the line in the file
|
||||
assert (
|
||||
cleaned_first_entry_line in line_in_file.strip() or cleaned_first_entry_line in next_line_in_file.strip()
|
||||
), f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
|
||||
), (
|
||||
f"First non-heading line '{cleaned_first_entry_line}' in {entry.raw} does not match line {line_number_from_uri} in file: '{line_in_file}' or next line '{next_line_in_file}'"
|
||||
)
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
||||
@@ -8,7 +8,7 @@ from khoj.processor.content.org_mode import orgnode
|
||||
def test_parse_entry_with_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f"""Body Line 1"""
|
||||
entry = """Body Line 1"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
@@ -30,7 +30,7 @@ def test_parse_entry_with_no_headings(tmp_path):
|
||||
def test_parse_minimal_entry(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
* Heading
|
||||
Body Line 1"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
@@ -54,7 +54,7 @@ Body Line 1"""
|
||||
def test_parse_complete_entry(tmp_path):
|
||||
"Test parsing of entry with all important fields"
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
*** DONE [#A] Heading :Tag1:TAG2:tag3:
|
||||
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
||||
:PROPERTIES:
|
||||
@@ -89,7 +89,7 @@ Body Line 2"""
|
||||
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
||||
"Render heading entry with property drawer"
|
||||
# Arrange
|
||||
entry_to_render = f"""
|
||||
entry_to_render = """
|
||||
*** [#A] Heading1 :tag1:
|
||||
:PROPERTIES:
|
||||
:ID: 111-111-111-1111-1111
|
||||
@@ -116,7 +116,7 @@ def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
||||
def test_all_links_to_entry_rendered(tmp_path):
|
||||
"Ensure all links to entry rendered in property drawer from entry"
|
||||
# Arrange
|
||||
entry = f"""
|
||||
entry = """
|
||||
*** [#A] Heading :tag1:
|
||||
:PROPERTIES:
|
||||
:ID: 123-456-789-4234-1231
|
||||
@@ -133,7 +133,7 @@ Body Line 2
|
||||
# Assert
|
||||
# SOURCE link rendered with Heading
|
||||
# ID link rendered with ID
|
||||
assert f":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
|
||||
assert ":ID: id:123-456-789-4234-1231" in f"{entries[0]}"
|
||||
# LINE link rendered with line number
|
||||
assert f":LINE: file://{orgfile}#line=2" in f"{entries[0]}"
|
||||
# LINE link rendered with line number
|
||||
@@ -144,7 +144,7 @@ Body Line 2
|
||||
def test_parse_multiple_entries(tmp_path):
|
||||
"Test parsing of multiple entries"
|
||||
# Arrange
|
||||
content = f"""
|
||||
content = """
|
||||
*** FAILED [#A] Heading1 :tag1:
|
||||
CLOSED: [1984-04-01 Sun 12:00] SCHEDULED: <1984-04-01 Sun 09:00> DEADLINE: <1984-04-01 Sun>
|
||||
:PROPERTIES:
|
||||
@@ -194,7 +194,7 @@ Body 2
|
||||
def test_parse_entry_with_empty_title(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f"""#+TITLE:
|
||||
entry = """#+TITLE:
|
||||
Body Line 1"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
@@ -217,7 +217,7 @@ Body Line 1"""
|
||||
def test_parse_entry_with_title_and_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f"""#+TITLE: test
|
||||
entry = """#+TITLE: test
|
||||
Body Line 1"""
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
@@ -241,7 +241,7 @@ Body Line 1"""
|
||||
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f"""#+TITLE: title1
|
||||
entry = """#+TITLE: title1
|
||||
Body Line 1
|
||||
#+TITLE: title2 """
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
@@ -266,7 +266,7 @@ Body Line 1
|
||||
def test_parse_org_with_intro_text_before_heading(tmp_path):
|
||||
"Test parsing of org file with intro text before heading"
|
||||
# Arrange
|
||||
body = f"""#+TITLE: Title
|
||||
body = """#+TITLE: Title
|
||||
intro body
|
||||
* Entry Heading
|
||||
entry body
|
||||
@@ -290,7 +290,7 @@ entry body
|
||||
def test_parse_org_with_intro_text_multiple_titles_and_heading(tmp_path):
|
||||
"Test parsing of org file with intro text, multiple titles and heading entry"
|
||||
# Arrange
|
||||
body = f"""#+TITLE: Title1
|
||||
body = """#+TITLE: Title1
|
||||
intro body
|
||||
* Entry Heading
|
||||
entry body
|
||||
@@ -314,7 +314,7 @@ entry body
|
||||
def test_parse_org_with_single_ancestor_heading(tmp_path):
|
||||
"Parse org entries with parent headings context"
|
||||
# Arrange
|
||||
body = f"""
|
||||
body = """
|
||||
* Heading 1
|
||||
body 1
|
||||
** Sub Heading 1
|
||||
@@ -336,7 +336,7 @@ body 1
|
||||
def test_parse_org_with_multiple_ancestor_headings(tmp_path):
|
||||
"Parse org entries with parent headings context"
|
||||
# Arrange
|
||||
body = f"""
|
||||
body = """
|
||||
* Heading 1
|
||||
body 1
|
||||
** Sub Heading 1
|
||||
@@ -362,7 +362,7 @@ sub sub body 1
|
||||
def test_parse_org_with_multiple_ancestor_headings_of_siblings(tmp_path):
|
||||
"Parse org entries with parent headings context"
|
||||
# Arrange
|
||||
body = f"""
|
||||
body = """
|
||||
* Heading 1
|
||||
body 1
|
||||
** Sub Heading 1
|
||||
|
||||
@@ -7,7 +7,7 @@ from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEnt
|
||||
def test_plaintext_file():
|
||||
"Convert files with no heading to jsonl."
|
||||
# Arrange
|
||||
raw_entry = f"""
|
||||
raw_entry = """
|
||||
Hi, I am a plaintext file and I have some plaintext words.
|
||||
"""
|
||||
plaintextfile = "test.txt"
|
||||
|
||||
@@ -145,9 +145,9 @@ def test_entry_chunking_by_max_tokens(tmp_path, search_config, default_user: Kho
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, (
|
||||
"new entry not split by max tokens"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -198,9 +198,9 @@ conda activate khoj
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
assert "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message, (
|
||||
"new entry not split by max tokens"
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user