Use single extract questions method across all LLMs for doc search

Using model-specific extract-questions methods was an artifact from older
times, when models were less guidable.

New changes collate and reuse logic
- Rely on send_message_to_model_wrapper for model specific formatting.
- Use the same prompt and context for all LLMs, as modern models can handle prompt variation.
- Use response schema enforcer to ensure response consistency across models.

Extract questions (because of its age) was the only tool implemented directly
inside each provider's code. Moved it into helpers so all the (mini) tools
live in one place.
This commit is contained in:
Debanjum
2025-06-05 02:15:58 -07:00
parent c2cd92a454
commit 2f4160e24b
8 changed files with 109 additions and 575 deletions

View File

@@ -1,23 +1,16 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -28,89 +21,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_anthropic(
    text,
    model: Optional[str] = "claude-3-7-sonnet-latest",
    chat_history: Optional[List[ChatMessageModel]] = None,
    api_key=None,
    api_base_url=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
):
    """
    Infer document search queries to retrieve notes relevant to the user query,
    using an Anthropic chat model.

    Returns a list of search query strings. Falls back to [text] when the model
    response cannot be parsed as the expected {"queries": [...]} JSON object.
    """
    # Avoid mutable default arguments: a shared []/{} default would leak state across calls.
    chat_history = chat_history if chat_history is not None else []
    tracer = tracer if tracer is not None else {}

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract past user messages and inferred questions from chat history for prompt context
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")

    # Dates relative to today, so the model can resolve relative date references in the query
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history_str,
        text=text,
    )

    content = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.ANTHROPIC,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = [ChatMessage(content=content, role="user")]

    raw_response = anthropic_completion_with_backoff(
        messages=messages,
        system_prompt=system_prompt,
        model_name=model,
        api_key=api_key,
        api_base_url=api_base_url,
        response_type="json_object",
        tracer=tracer,
    )

    # Extract and clean search queries from Claude's JSON response.
    # Parse into a distinct variable; reusing `response` for raw and parsed values
    # previously made the fallback log print the parsed fragment, not the raw reply.
    try:
        parsed = pyjson5.loads(clean_json(raw_response))
        questions = [q.strip() for q in parsed["queries"] if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {parsed}")
            return [text]
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
        logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{raw_response}")
        return [text]

    logger.debug(f"Extracted Questions by Claude: {questions}")
    return questions
def anthropic_send_message_to_model(
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
):

View File

@@ -1,12 +1,8 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, Field
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
gemini_chat_completion_with_backoff,
@@ -15,9 +11,6 @@ from khoj.processor.conversation.google.utils import (
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -28,96 +21,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_gemini(
    text,
    model: Optional[str] = "gemini-2.0-flash",
    chat_history: Optional[List[ChatMessageModel]] = None,
    api_key=None,
    api_base_url=None,
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
):
    """
    Infer document search queries to retrieve notes relevant to the user query,
    using a Google Gemini chat model.

    Returns a list of search query strings. Falls back to [text] when the model
    response cannot be parsed as the expected {"queries": [...]} JSON object.
    """
    # Avoid mutable default arguments: a shared []/{} default would leak state across calls.
    chat_history = chat_history if chat_history is not None else []
    tracer = tracer if tracer is not None else {}

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract past user messages and inferred questions from chat history for prompt context
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")

    # Dates relative to today, so the model can resolve relative date references in the query
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    # NOTE(review): intentionally reuses the Anthropic extraction prompts for Gemini;
    # the prompt text is provider-agnostic — confirm before renaming.
    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history_str,
        text=text,
    )

    prompt = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.GOOGLE,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = [
        ChatMessage(content=prompt, role="user"),
        ChatMessage(content=system_prompt, role="system"),
    ]

    class DocumentQueries(BaseModel):
        # Response schema: enforce at least one search query in the reply
        queries: List[str] = Field(..., min_items=1)

    raw_response = gemini_send_message_to_model(
        messages,
        api_key,
        model,
        api_base_url=api_base_url,
        response_type="json_object",
        response_schema=DocumentQueries,
        tracer=tracer,
    )

    # Extract and clean search queries from Gemini's JSON response
    try:
        parsed = pyjson5.loads(clean_json(raw_response))
        questions = [q.strip() for q in parsed["queries"] if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {parsed}")
            return [text]
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
        logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{raw_response}")
        return [text]

    logger.debug(f"Extracted Questions by Gemini: {questions}")
    return questions
def gemini_send_message_to_model(
messages,
api_key,

View File

@@ -1,28 +1,24 @@
import asyncio
import logging
import os
from datetime import datetime, timedelta
from datetime import datetime
from threading import Thread
from time import perf_counter
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Union
import pyjson5
from langchain_core.messages.chat import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ResponseWithThought,
clean_json,
commit_conversation_trace,
construct_question_history,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import (
is_none_or_empty,
is_promptrace_enabled,
@@ -34,114 +30,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions_offline(
    text: str,
    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    chat_history: Optional[List[ChatMessageModel]] = None,
    use_history: bool = True,
    should_extract_questions: bool = True,
    location_data: LocationData = None,
    user: KhojUser = None,
    max_prompt_size: int = None,
    temperature: float = 0.7,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
) -> List[str]:
    """
    Infer document search queries to retrieve notes relevant to the user query,
    using an offline (llama.cpp) chat model.

    Returns a list of search query strings. Falls back to a naive sentence
    split of the user message if extraction is disabled or parsing fails.
    """
    # Avoid mutable default arguments: a shared []/{} default would leak state across calls.
    chat_history = chat_history if chat_history is not None else []
    tracer = tracer if tracer is not None else {}

    # Naive "?"-based split of the raw query; used when extraction is disabled
    # and as the fallback when the model reply is unparseable
    all_questions = text.split("? ")
    all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
    if not should_extract_questions:
        return all_questions

    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract past user messages and inferred questions from chat history for prompt context
    chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else ""

    # Dates relative to today, so the model can resolve relative date references in the query
    today = datetime.today()
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
    last_year = today.year - 1

    example_questions = prompts.extract_questions_offline.format(
        query=text,
        chat_history=chat_history_str,
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        yesterday_date=yesterday,
        last_year=last_year,
        this_year=today.year,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    messages = generate_chatml_messages_with_context(
        example_questions,
        model_name=model,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        model_type=ChatModel.ModelType.OFFLINE,
        query_files=query_files,
    )

    # Serialize access to the shared offline model; `with` replaces manual acquire/release
    with state.chat_lock:
        response = send_message_to_model_offline(
            messages,
            loaded_model=offline_chat_model,
            model_name=model,
            max_prompt_size=max_prompt_size,
            temperature=temperature,
            response_type="json_object",
            tracer=tracer,
        )

    # Extract and clean the chat model's response.
    # Bug fix: previously parsed the `empty_escape_sequences` constant instead of
    # the model `response`, so valid replies were never extracted.
    try:
        parsed = pyjson5.loads(clean_json(response))
        questions = [q.strip() for q in parsed["queries"] if q.strip()]
        questions = filter_questions(questions)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
        logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return all_questions

    logger.debug(f"Questions extracted by {model}: {questions}")
    return questions
def filter_questions(questions: List[str]) -> List[str]:
    """
    Drop questions that look like the model apologizing for being unable to
    answer, along with empty entries and duplicates.

    Preserves input order: the previous `list(set(...))` dedupe returned the
    questions in nondeterministic order across runs.
    """
    # Phrases hinting the model is apologizing rather than emitting a search query
    hint_words = [
        "sorry",
        "apologize",
        "unable",
        "can't",
        "cannot",
        "don't know",
        "don't understand",
        "do not know",
        "do not understand",
    ]
    # dict preserves insertion order (Python 3.7+), giving an order-stable dedupe
    filtered_questions: dict = {}
    for question in questions:
        if is_none_or_empty(question):
            continue
        lowered = question.lower()
        if any(word in lowered for word in hint_words):
            continue
        filtered_questions[question] = None
    return list(filtered_questions)
async def converse_offline(
# Query
user_query: str,
@@ -324,7 +212,7 @@ def send_message_to_model_offline(
if streaming:
return response
response_text = response["choices"][0]["message"].get("content", "")
response_text: str = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function

View File

@@ -1,13 +1,11 @@
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
import pyjson5
from langchain_core.messages.chat import ChatMessage
from openai.lib._pydantic import _ensure_strict_json_schema
from pydantic import BaseModel
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@@ -18,9 +16,6 @@ from khoj.processor.conversation.utils import (
JsonSupport,
OperatorRun,
ResponseWithThought,
clean_json,
construct_question_history,
construct_structured_message,
generate_chatml_messages_with_context,
messages_to_print,
)
@@ -31,88 +26,6 @@ from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
def extract_questions(
    text,
    model: Optional[str] = "gpt-4o-mini",
    chat_history: Optional[list[ChatMessageModel]] = None,
    api_key=None,
    api_base_url=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
):
    """
    Infer document search queries to retrieve notes relevant to the user query,
    using an OpenAI chat model.

    Returns a list of search query strings. Falls back to [text] when the model
    response cannot be parsed as the expected {"queries": [...]} JSON object.
    """
    # Avoid mutable default arguments: a shared []/{} default would leak state across calls.
    chat_history = chat_history if chat_history is not None else []
    tracer = tracer if tracer is not None else {}

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract past user messages and inferred questions from chat history for prompt context
    chat_history_str = construct_question_history(chat_history)

    # Dates relative to today, so the model can resolve relative date references in the query
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    prompt = prompts.extract_questions.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        bob_tom_age_difference={current_new_year.year - 1984 - 30},
        bob_age={current_new_year.year - 1984},
        chat_history=chat_history_str,
        text=text,
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = construct_structured_message(
        message=prompt,
        images=query_images,
        model_type=ChatModel.ModelType.OPENAI,
        vision_enabled=vision_enabled,
        attached_file_context=query_files,
    )

    messages = [ChatMessage(content=prompt, role="user")]

    raw_response = send_message_to_model(
        messages,
        api_key,
        model,
        response_type="json_object",
        api_base_url=api_base_url,
        tracer=tracer,
    )

    # Extract and clean search queries from GPT's JSON response
    try:
        parsed = pyjson5.loads(clean_json(raw_response))
        questions = [q.strip() for q in parsed["queries"] if q.strip()]
        if not questions:
            logger.error(f"Invalid response for constructing subqueries: {parsed}")
            return [text]
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
        logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{raw_response}")
        return [text]

    logger.debug(f"Extracted Questions by GPT: {questions}")
    return questions
def send_message_to_model(
messages,
api_key,

View File

@@ -549,68 +549,7 @@ Q: {query}
)
# Few-shot prompt for the OpenAI extract-questions path: single combined message
# with instructions, dated examples, and the live chat history appended at the end.
extract_questions = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Examples
---
Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
Khoj: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: What national parks did I go to last year?
Khoj: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
Q: How can you help me?
Khoj: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
A: I can help you live healthier and happier across work and personal life
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
A: 1085 tennis balls will fit in the trunk of a Honda Civic
Q: Share some random, interesting experiences from this month
Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
Q: Is Bob older than Tom?
Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.
Q: What is their age difference?
Khoj: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
A: Bob is {bob_tom_age_difference} years older than Tom. As Bob is {bob_age} years old and Tom is 30 years old.
Q: Who all did I meet here yesterday?
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
Actual
---
{chat_history}
Q: {text}
Khoj:
""".strip()
)
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
extract_questions_system_prompt = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
Construct search queries to retrieve relevant information to answer the user's question.
@@ -651,7 +590,7 @@ A: You had a great time at the local beach with your friends, attended a music c
""".strip()
)
extract_questions_anthropic_user_message = PromptTemplate.from_template(
extract_questions_user_message = PromptTemplate.from_template(
"""
Here's our most recent chat history:
{chat_history}

View File

@@ -82,21 +82,17 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
extract_questions_anthropic,
)
from khoj.processor.conversation.google.gemini_chat import (
converse_gemini,
extract_questions_gemini,
gemini_send_message_to_model,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import (
converse_openai,
extract_questions,
send_message_to_model,
)
from khoj.processor.conversation.utils import (
@@ -107,6 +103,7 @@ from khoj.processor.conversation.utils import (
clean_json,
clean_mermaidjs,
construct_chat_history,
construct_question_history,
defilter_query,
generate_chatml_messages_with_context,
)
@@ -1222,7 +1219,6 @@ async def search_documents(
return
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
using_offline_chat = False
if is_none_or_empty(filters_in_query):
logger.debug(f"Filters in query: {filters_in_query}")
@@ -1230,89 +1226,18 @@ async def search_documents(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
chat_model = await ConversationAdapters.aget_default_chat_model(user)
vision_enabled = chat_model.vision_enabled
inferred_queries = await extract_questions(
query=defiltered_query,
user=user,
personality_context=personality_context,
chat_history=chat_history,
location_data=location_data,
query_images=query_images,
query_files=query_files,
tracer=tracer,
)
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
using_offline_chat = True
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query,
model=chat_model,
loaded_model=loaded_model,
chat_history=chat_history,
should_extract_questions=True,
location_data=location_data,
user=user,
max_prompt_size=chat_model.max_prompt_size,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions(
defiltered_query,
model=chat_model_name,
api_key=api_key,
api_base_url=base_url,
chat_history=chat_history,
location_data=location_data,
user=user,
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
chat_history=chat_history,
location_data=location_data,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model_name,
api_key=api_key,
api_base_url=api_base_url,
chat_history=chat_history,
location_data=location_data,
max_tokens=chat_model.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
# Collate search results as context for GPT
# Collate search results as context for the LLM
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
@@ -1322,12 +1247,11 @@ async def search_documents(
async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
yield {ChatEvent.STATUS: event}
for query in inferred_queries:
n_items = min(n, 3) if using_offline_chat else n
search_results.extend(
await execute_search(
user if not should_limit_to_agent_knowledge else None,
f"{query} {filters_in_query}",
n=n_items,
n=n,
t=state.SearchType.All,
r=True,
max_distance=d,
@@ -1344,6 +1268,78 @@ async def search_documents(
yield compiled_references, inferred_queries, defiltered_query
async def extract_questions(
    query: str,
    user: KhojUser,
    personality_context: str = "",
    chat_history: Optional[List[ChatMessageModel]] = None,
    location_data: LocationData = None,
    query_images: Optional[List[str]] = None,
    query_files: str = None,
    tracer: Optional[dict] = None,
):
    """
    Infer document search queries from the user message and provided context.

    Uses the shared send_message_to_model_wrapper so a single prompt and
    response schema works across all configured chat model providers.
    Returns a list of search query strings; falls back to [query] when the
    model response cannot be parsed as the expected {"queries": [...]} object.
    """
    # Avoid mutable default arguments: a shared []/{} default would leak state across calls.
    chat_history = chat_history if chat_history is not None else []
    tracer = tracer if tracer is not None else {}

    # Shared context setup
    location = f"{location_data}" if location_data else "N/A"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Date variables for prompt formatting, so the model can resolve relative dates
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")

    # Same prompt for all LLMs; provider-specific formatting is handled by the wrapper
    chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
    system_prompt = prompts.extract_questions_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=yesterday,
        location=location,
        username=username,
        personality_context=personality_context,
    )
    prompt = prompts.extract_questions_user_message.format(text=query, chat_history=chat_history_str)

    class DocumentQueries(BaseModel):
        """Choose searches to run on user documents."""

        queries: List[str] = Field(..., min_items=1, description="List of search queries to run on user documents.")

    raw_response = await send_message_to_model_wrapper(
        system_message=system_prompt,
        query=prompt,
        query_images=query_images,
        query_files=query_files,
        chat_history=chat_history,
        response_type="json_object",
        response_schema=DocumentQueries,
        user=user,
        tracer=tracer,
    )

    # Extract search queries from the model's JSON response
    try:
        parsed = pyjson5.loads(clean_json(raw_response))
        queries = [q.strip() for q in parsed["queries"] if q.strip()]
        if not queries:
            logger.error(f"Invalid response for constructing subqueries: {parsed}")
            return [query]
        return queries
    except Exception:
        # Narrowed from a bare `except:`; include the raw response for debugging
        # (the previous warning was an f-string with no placeholder, logging nothing useful)
        logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.\n{raw_response}")
        return [query]
async def execute_search(
user: KhojUser,
q: str,