Use single extract questions method across all LLMs for doc search

Using model-specific extract questions was an artifact from older
times, when models were less guidable.

New changes collate and reuse logic
- Rely on send_message_to_model_wrapper for model specific formatting.
- Use the same prompt and context for all LLMs, as modern models can handle prompt variation.
- Use response schema enforcer to ensure response consistency across models.

Extract questions (because of its age) was the only tool implemented
directly within each provider's code. Move it into the shared helpers so
all the (mini) tools live in one place.
This commit is contained in:
Debanjum
2025-06-05 02:15:58 -07:00
parent c2cd92a454
commit 2f4160e24b
8 changed files with 109 additions and 575 deletions

View File

@@ -1,23 +1,16 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -28,89 +21,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_anthropic(
    text,
    model: Optional[str] = "claude-3-7-sonnet-latest",
    chat_history: List[ChatMessageModel] = [],
    api_key=None,
    api_base_url=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Infer search queries to retrieve relevant notes to answer user query.

    Returns a list of search query strings parsed from Claude's JSON
    response. Falls back to `[text]` whenever the response cannot be parsed
    into a non-empty list of queries.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Chat History
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history_str,
        text=text,
    )

    # Attach any images/files to the user message in Anthropic's format
    content = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.ANTHROPIC,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = [ChatMessage(content=content, role="user")]

    response = anthropic_completion_with_backoff(
        messages=messages,
        system_prompt=system_prompt,
        model_name=model,
        api_key=api_key,
        api_base_url=api_base_url,
        response_type="json_object",
        tracer=tracer,
    )

    # Extract, Clean Message from Claude's Response.
    # Fix: validate the parsed payload *before* building the query list. The
    # previous isinstance check ran after the list comprehension, so it could
    # never observe a non-list payload (a bad payload raised instead and was
    # swallowed by a bare `except:`).
    try:
        parsed = pyjson5.loads(clean_json(response))
        raw_queries = parsed.get("queries") if isinstance(parsed, dict) else None
        if not isinstance(raw_queries, list):
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
        questions = [q.strip() for q in raw_queries if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
    except Exception:
        # Narrowed from bare `except:`, which also caught SystemExit/KeyboardInterrupt
        logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return [text]

    logger.debug(f"Extracted Questions by Claude: {questions}")
    return questions
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
):

View File

@@ -1,12 +1,8 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
gemini_chat_completion_with_backoff,
@@ -15,9 +11,6 @@ from khoj.processor.conversation.google.utils import (
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -28,96 +21,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_gemini(
    text,
    model: Optional[str] = "gemini-2.0-flash",
    chat_history: List[ChatMessageModel] = [],
    api_key=None,
    api_base_url=None,
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Infer search queries to retrieve relevant notes to answer user query.

    Returns a list of search query strings parsed from Gemini's JSON
    response. Falls back to `[text]` whenever the response cannot be parsed
    into a non-empty list of queries.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Chat History
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    # NOTE: deliberately reuses the Anthropic prompt templates for consistency
    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history_str,
        text=text,
    )

    # Attach any images/files to the user message in Gemini's format
    prompt = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.GOOGLE,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = []
    messages.append(ChatMessage(content=prompt, role="user"))
    messages.append(ChatMessage(content=system_prompt, role="system"))

    class DocumentQueries(BaseModel):
        queries: List[str] = Field(..., min_items=1)

    response = gemini_send_message_to_model(
        messages,
        api_key,
        model,
        api_base_url=api_base_url,
        response_type="json_object",
        response_schema=DocumentQueries,
        tracer=tracer,
    )

    # Extract, Clean Message from Gemini's Response.
    # Fix: validate the parsed payload *before* building the query list. The
    # previous isinstance check ran after the list comprehension, so it could
    # never observe a non-list payload (a bad payload raised instead and was
    # swallowed by a bare `except:`).
    try:
        parsed = pyjson5.loads(clean_json(response))
        raw_queries = parsed.get("queries") if isinstance(parsed, dict) else None
        if not isinstance(raw_queries, list):
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
        questions = [q.strip() for q in raw_queries if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
    except Exception:
        # Narrowed from bare `except:`, which also caught SystemExit/KeyboardInterrupt
        logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return [text]

    logger.debug(f"Extracted Questions by Gemini: {questions}")
    return questions
def gemini_send_message_to_model(
messages,
api_key,

View File

@@ -1,28 +1,24 @@
import asyncio
import logging
import os
from datetime import datetime, timedelta
from datetime import datetime
from threading import Thread
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Union
import pyjson5
from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ResponseWithThought,
clean_json,
commit_conversation_trace,
construct_question_history,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import (
is_none_or_empty,
is_promptrace_enabled,
@@ -34,114 +30,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_offline(
    text: str,
    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    chat_history: List[ChatMessageModel] = [],
    use_history: bool = True,
    should_extract_questions: bool = True,
    location_data: LocationData = None,
    user: KhojUser = None,
    max_prompt_size: int = None,
    temperature: float = 0.7,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: dict = {},
) -> List[str]:
    """
    Infer search queries to retrieve relevant notes to answer user query.

    Uses a locally loaded Llama model. Returns a list of search query
    strings; falls back to splitting the user message on "? " when
    extraction is disabled or the model response cannot be parsed.
    """
    # Naive fallback queries: split the user message on question marks
    all_questions = text.split("? ")
    all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]

    if not should_extract_questions:
        return all_questions

    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Conversation Log
    chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else ""

    # Get dates relative to today for prompt creation
    today = datetime.today()
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
    last_year = today.year - 1

    example_questions = prompts.extract_questions_offline.format(
        query=text,
        chat_history=chat_history_str,
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        yesterday_date=yesterday,
        last_year=last_year,
        this_year=today.year,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    messages = generate_chatml_messages_with_context(
        example_questions,
        model_name=model,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        model_type=ChatModel.ModelType.OFFLINE,
        query_files=query_files,
    )

    # Serialize access to the local model; release the lock even on error
    state.chat_lock.acquire()
    try:
        response = send_message_to_model_offline(
            messages,
            loaded_model=offline_chat_model,
            model_name=model,
            max_prompt_size=max_prompt_size,
            temperature=temperature,
            response_type="json_object",
            tracer=tracer,
        )
    finally:
        state.chat_lock.release()

    # Extract and clean the chat model's response.
    # Fix: clean the model's *response*, not the `empty_escape_sequences`
    # constant — the previous code passed the constant to clean_json and
    # discarded the model output entirely, so parsing always failed.
    try:
        response = clean_json(response)
        response = pyjson5.loads(response)
        questions = [q.strip() for q in response["queries"] if q.strip()]
        questions = filter_questions(questions)
    except Exception:
        # Narrowed from bare `except:`, which also caught SystemExit/KeyboardInterrupt
        logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return all_questions

    logger.debug(f"Questions extracted by {model}: {questions}")
    return questions
def filter_questions(questions: List[str]):
    """Drop empty questions and ones where the model apologizes for not knowing.

    Returns a deduplicated list of the remaining questions.
    """
    # Phrases that signal the model is apologizing instead of answering
    hint_words = [
        "sorry",
        "apologize",
        "unable",
        "can't",
        "cannot",
        "don't know",
        "don't understand",
        "do not know",
        "do not understand",
    ]
    kept = {
        question
        for question in questions
        if not is_none_or_empty(question) and not any(hint in question.lower() for hint in hint_words)
    }
    return list(kept)
async def converse_offline(
# Query
user_query: str,
@@ -324,7 +212,7 @@ def send_message_to_model_offline(
if streaming:
return response
response_text = response["choices"][0]["message"].get("content", "")
response_text: str = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function

View File

@@ -1,13 +1,11 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@@ -18,9 +16,6 @@ from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -31,88 +26,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions(
    text,
    model: Optional[str] = "gpt-4o-mini",
    chat_history: list[ChatMessageModel] = [],
    api_key=None,
    api_base_url=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Infer search queries to retrieve relevant notes to answer user query.

    Returns a list of search query strings parsed from the model's JSON
    response. Falls back to `[text]` whenever the response cannot be parsed
    into a non-empty list of queries.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Chat History
    chat_history_str = construct_question_history(chat_history)

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    prompt = prompts.extract_questions.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        # Fix: these were set literals ({expr}), so the prompt rendered as
        # e.g. "{41}" instead of "41". Pass the plain integers.
        bob_tom_age_difference=current_new_year.year - 1984 - 30,
        bob_age=current_new_year.year - 1984,
        chat_history=chat_history_str,
        text=text,
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    # Attach any images/files to the user message in OpenAI's format
    prompt = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.OPENAI,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = []
    messages.append(ChatMessage(content=prompt, role="user"))

    response = send_message_to_model(
        messages,
        api_key,
        model,
        response_type="json_object",
        api_base_url=api_base_url,
        tracer=tracer,
    )

    # Extract, Clean Message from GPT's Response.
    # Fix: validate the parsed payload *before* building the query list. The
    # previous isinstance check ran after the list comprehension, so it could
    # never observe a non-list payload (a bad payload raised instead and was
    # swallowed by a bare `except:`).
    try:
        parsed = pyjson5.loads(clean_json(response))
        raw_queries = parsed.get("queries") if isinstance(parsed, dict) else None
        if not isinstance(raw_queries, list):
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
        questions = [q.strip() for q in raw_queries if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
    except Exception:
        # Narrowed from bare `except:`, which also caught SystemExit/KeyboardInterrupt
        logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return [text]

    logger.debug(f"Extracted Questions by GPT: {questions}")
    return questions
def send_message_to_model(
messages,
api_key,

View File

@@ -549,68 +549,7 @@ Q: {query}
)
extract_questions = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Examples
---
Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
Khoj: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: What national parks did I go to last year?
Khoj: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
Q: How can you help me?
Khoj: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
A: I can help you live healthier and happier across work and personal life
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
A: 1085 tennis balls will fit in the trunk of a Honda Civic
Q: Share some random, interesting experiences from this month
Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
Q: Is Bob older than Tom?
Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.
Q: What is their age difference?
Khoj: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
A: Bob is {bob_tom_age_difference} years older than Tom. As Bob is {bob_age} years old and Tom is 30 years old.
Q: Who all did I meet here yesterday?
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
Actual
---
{chat_history}
Q: {text}
Khoj:
""".strip()
)
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
extract_questions_system_prompt = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
Construct search queries to retrieve relevant information to answer the user's question.
@@ -651,7 +590,7 @@ A: You had a great time at the local beach with your friends, attended a music c
""".strip()
)
extract_questions_anthropic_user_message = PromptTemplate.from_template(
extract_questions_user_message = PromptTemplate.from_template(
"""
Here's our most recent chat history:
{chat_history}

View File

@@ -82,21 +82,17 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
extract_questions_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
extract_questions_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
extract_questions,
send_message_to_model,
)
from khoj.processor.conversation.utils import (
@@ -107,6 +103,7 @@ from khoj.processor.conversation.utils import (
clean_json,
clean_mermaidjs,
construct_chat_history,
construct_question_history,
defilter_query,
generate_chatml_messages_with_context,
)
@@ -1222,7 +1219,6 @@ async def search_documents(
return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False
if is_none_or_empty(filters_in_query):
logger.debug(f"Filters in query: {filters_in_query}")
@@ -1230,89 +1226,18 @@ async def search_documents(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
chat_model = await ConversationAdapters.aget_default_chat_model(user)
vision_enabled = chat_model.vision_enabled
inferred_queries = await extract_questions(
query=defiltered_query,
user=user,
personality_context=personality_context,
chat_history=chat_history,
location_data=location_data,
query_images=query_images,
query_files=query_files,
tracer=tracer,
)
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
using_offline_chat = True
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query,
model=chat_model,
loaded_model=loaded_model,
chat_history=chat_history,
should_extract_questions=True,
location_data=location_data,
user=user,
max_prompt_size=chat_model.max_prompt_size,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions(
defiltered_query,
model=chat_model_name,
api_key=api_key,
api_base_url=base_url,
chat_history=chat_history,
location_data=location_data,
user=user,
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
chat_history=chat_history,
location_data=location_data,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
chat_history=chat_history,
location_data=location_data,
max_tokens=chat_model.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
# Collate search results as context for GPT
# Collate search results as context for the LLM
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
@@ -1322,12 +1247,11 @@ async def search_documents(
async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
yield {ChatEvent.STATUS: event}
for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n
search_results.extend(
await execute_search(
user if not should_limit_to_agent_knowledge else None,
f"{query} {filters_in_query}",
n=n_items,
n=n,
t=state.SearchType.All,
r=True,
max_distance=d,
@@ -1344,6 +1268,78 @@ async def search_documents(
yield compiled_references, inferred_queries, defiltered_query
async def extract_questions(
    query: str,
    user: KhojUser,
    personality_context: str = "",
    chat_history: List[ChatMessageModel] = [],
    location_data: LocationData = None,
    query_images: Optional[List[str]] = None,
    query_files: str = None,
    tracer: dict = {},
):
    """
    Infer document search queries from user message and provided context.

    Model-agnostic: relies on send_message_to_model_wrapper for model
    specific formatting and on a response schema to keep the output
    consistent across providers. Returns a list of search query strings,
    falling back to `[query]` when the response cannot be parsed into a
    non-empty list of queries.
    """
    # Shared context setup
    location = f"{location_data}" if location_data else "N/A"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Date variables for prompt formatting
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")

    # Common prompt setup for all models
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
    system_prompt = prompts.extract_questions_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=yesterday,
        location=location,
        username=username,
        personality_context=personality_context,
    )
    prompt = prompts.extract_questions_user_message.format(text=query, chat_history=chat_history_str)

    class DocumentQueries(BaseModel):
        """Choose searches to run on user documents."""

        queries: List[str] = Field(..., min_items=1, description="List of search queries to run on user documents.")

    raw_response = await send_message_to_model_wrapper(
        system_message=system_prompt,
        query=prompt,
        query_images=query_images,
        query_files=query_files,
        chat_history=chat_history,
        response_type="json_object",
        response_schema=DocumentQueries,
        user=user,
        tracer=tracer,
    )

    # Extract questions from the response.
    # Fix: validate the parsed payload *before* building the query list. The
    # previous isinstance check ran after the list comprehension, so it could
    # never observe a non-list payload (a bad payload raised instead and was
    # swallowed by a bare `except:`).
    try:
        parsed = pyjson5.loads(clean_json(raw_response))
        raw_queries = parsed.get("queries") if isinstance(parsed, dict) else None
        if not isinstance(raw_queries, list):
            logger.error(f"Invalid response for constructing subqueries: {raw_response}")
            return [query]
        queries = [q.strip() for q in raw_queries if q.strip()]
        if not queries:
            logger.error(f"Invalid response for constructing subqueries: {raw_response}")
            return [query]
        return queries
    except Exception:
        # Include the raw response in the log, matching the other extractors
        logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.\n{raw_response}")
        return [query]
async def execute_search(
user: KhojUser,
q: str,

View File

@@ -3,7 +3,7 @@ from datetime import datetime
import pytest
from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
@@ -16,11 +16,7 @@ pytestmark = pytest.mark.skipif(
import freezegun
from freezegun import freeze_time
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
filter_questions,
)
from khoj.processor.conversation.offline.chat_model import converse_offline
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.constants import default_offline_chat_models
@@ -39,7 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# Act
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model)
assert len(response) >= 1
@@ -59,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# Act
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model)
# Assert
assert len(response) >= 1
@@ -81,7 +77,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
# Act
response = extract_questions_offline("Which countries have I visited this year?")
response = extract_questions("Which countries have I visited this year?")
# Assert
expected_responses = [
@@ -99,7 +95,7 @@ def test_extract_question_with_date_filter_from_relative_year():
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
# Act
responses = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model)
# Assert
assert len(responses) >= 2
@@ -111,7 +107,7 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model)
# Assert
expected_responses = ["height", "taller", "shorter", "heights", "who"]
@@ -133,7 +129,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
query = "Does he have any sons?"
# Act
response = extract_questions_offline(
response = extract_questions(
query,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -179,7 +175,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
]
# Act
response = extract_questions_offline(
response = extract_questions(
"Is she a Doctor?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -208,7 +204,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
]
# Act
response = extract_questions_offline(
response = extract_questions(
"What was the Pizza place we ate at over there?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -609,15 +605,3 @@ def test_chat_does_not_exceed_prompt_size(loaded_model):
assert prompt_size_exceeded_error not in response, (
"Expected chat response to be within prompt limits, but got exceeded error: " + response
)
# ----------------------------------------------------------------------------------------------------
def test_filter_questions():
test_questions = [
"I don't know how to answer that",
"I cannot answer anything about the nuclear secrets",
"Who is on the basketball team?",
]
filtered_questions = filter_questions(test_questions)
assert len(filtered_questions) == 1
assert filtered_questions[0] == "Who is on the basketball team?"

View File

@@ -4,10 +4,11 @@ import freezegun
import pytest
from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse_openai, extract_questions
from khoj.processor.conversation.openai.gpt import converse_openai
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
aget_data_sources_and_output_format,
extract_questions,
generate_online_subqueries,
infer_webpage_urls,
schedule_query,