mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Use single extract questions method across all LLMs for doc search
Using model specific extract questions was an artifact from older times, with less guidable models. New changes collate and reuse logic - Rely on send_message_to_model_wrapper for model specific formatting. - Use same prompt, context for all LLMs as can handle prompt variation. - Use response schema enforcer to ensure response consistency across models. Extract questions (because of its age) was the only tool directly within each provider code. Put it into helpers to have all the (mini) tools in one place.
This commit is contained in:
@@ -1,23 +1,16 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.utils import (
|
||||
anthropic_chat_completion_with_backoff,
|
||||
anthropic_completion_with_backoff,
|
||||
format_messages_for_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
@@ -28,89 +21,6 @@ from khoj.utils.yaml import yaml_dump
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions_anthropic(
|
||||
text,
|
||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
|
||||
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
)
|
||||
|
||||
content = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=content, role="user")]
|
||||
|
||||
response = anthropic_completion_with_backoff(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Claude's Response
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = pyjson5.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [text]
|
||||
return response
|
||||
except:
|
||||
logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
logger.debug(f"Extracted Questions by Claude: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(
|
||||
messages, api_key, api_base_url, model, response_type="text", response_schema=None, deepthought=False, tracer={}
|
||||
):
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
gemini_chat_completion_with_backoff,
|
||||
@@ -15,9 +11,6 @@ from khoj.processor.conversation.google.utils import (
|
||||
from khoj.processor.conversation.utils import (
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
@@ -28,96 +21,6 @@ from khoj.utils.yaml import yaml_dump
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions_gemini(
|
||||
text,
|
||||
model: Optional[str] = "gemini-2.0-flash",
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
|
||||
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
)
|
||||
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModel.ModelType.GOOGLE,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
||||
messages = []
|
||||
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
class DocumentQueries(BaseModel):
|
||||
queries: List[str] = Field(..., min_items=1)
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
response_schema=DocumentQueries,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = pyjson5.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [text]
|
||||
return response
|
||||
except:
|
||||
logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
logger.debug(f"Extracted Questions by Gemini: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from time import perf_counter
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||
|
||||
import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from llama_cpp import Llama
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
commit_conversation_trace,
|
||||
construct_question_history,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import (
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
@@ -34,114 +30,6 @@ from khoj.utils.yaml import yaml_dump
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions_offline(
|
||||
text: str,
|
||||
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
use_history: bool = True,
|
||||
should_extract_questions: bool = True,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
max_prompt_size: int = None,
|
||||
temperature: float = 0.7,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
all_questions = text.split("? ")
|
||||
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
|
||||
|
||||
if not should_extract_questions:
|
||||
return all_questions
|
||||
|
||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history_str = construct_question_history(chat_history, include_query=False) if use_history else ""
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
last_year = today.year - 1
|
||||
example_questions = prompts.extract_questions_offline.format(
|
||||
query=text,
|
||||
chat_history=chat_history_str,
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
yesterday_date=yesterday,
|
||||
last_year=last_year,
|
||||
this_year=today.year,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
messages = generate_chatml_messages_with_context(
|
||||
example_questions,
|
||||
model_name=model,
|
||||
loaded_model=offline_chat_model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
model_type=ChatModel.ModelType.OFFLINE,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
state.chat_lock.acquire()
|
||||
try:
|
||||
response = send_message_to_model_offline(
|
||||
messages,
|
||||
loaded_model=offline_chat_model,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
temperature=temperature,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
finally:
|
||||
state.chat_lock.release()
|
||||
|
||||
# Extract and clean the chat model's response
|
||||
try:
|
||||
response = clean_json(empty_escape_sequences)
|
||||
response = pyjson5.loads(response)
|
||||
questions = [q.strip() for q in response["queries"] if q.strip()]
|
||||
questions = filter_questions(questions)
|
||||
except:
|
||||
logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
return all_questions
|
||||
logger.debug(f"Questions extracted by {model}: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
def filter_questions(questions: List[str]):
|
||||
# Skip questions that seem to be apologizing for not being able to answer the question
|
||||
hint_words = [
|
||||
"sorry",
|
||||
"apologize",
|
||||
"unable",
|
||||
"can't",
|
||||
"cannot",
|
||||
"don't know",
|
||||
"don't understand",
|
||||
"do not know",
|
||||
"do not understand",
|
||||
]
|
||||
filtered_questions = set()
|
||||
for q in questions:
|
||||
if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
|
||||
filtered_questions.add(q)
|
||||
|
||||
return list(filtered_questions)
|
||||
|
||||
|
||||
async def converse_offline(
|
||||
# Query
|
||||
user_query: str,
|
||||
@@ -324,7 +212,7 @@ def send_message_to_model_offline(
|
||||
if streaming:
|
||||
return response
|
||||
|
||||
response_text = response["choices"][0]["message"].get("content", "")
|
||||
response_text: str = response["choices"][0]["message"].get("content", "")
|
||||
|
||||
# Save conversation trace for non-streaming responses
|
||||
# Streamed responses need to be saved by the calling function
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from openai.lib._pydantic import _ensure_strict_json_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
@@ -18,9 +16,6 @@ from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
OperatorRun,
|
||||
ResponseWithThought,
|
||||
clean_json,
|
||||
construct_question_history,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
@@ -31,88 +26,6 @@ from khoj.utils.yaml import yaml_dump
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions(
|
||||
text,
|
||||
model: Optional[str] = "gpt-4o-mini",
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Chat History
|
||||
chat_history_str = construct_question_history(chat_history)
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
|
||||
prompt = prompts.extract_questions.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
bob_tom_age_difference={current_new_year.year - 1984 - 30},
|
||||
bob_age={current_new_year.year - 1984},
|
||||
chat_history=chat_history_str,
|
||||
text=text,
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
||||
messages = []
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
|
||||
response = send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
response_type="json_object",
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = pyjson5.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [text]
|
||||
return response
|
||||
except:
|
||||
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
|
||||
logger.debug(f"Extracted Questions by GPT: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
def send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
|
||||
@@ -549,68 +549,7 @@ Q: {query}
|
||||
)
|
||||
|
||||
|
||||
extract_questions = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Examples
|
||||
---
|
||||
Q: How was my trip to Cambodia?
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
Q: Who did i visit that temple with?
|
||||
Khoj: {{"queries": ["Who did I visit the Angkor Wat Temple in Cambodia with?"]}}
|
||||
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
|
||||
|
||||
Q: What national parks did I go to last year?
|
||||
Khoj: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
|
||||
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
|
||||
|
||||
Q: How can you help me?
|
||||
Khoj: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
|
||||
A: I can help you live healthier and happier across work and personal life
|
||||
|
||||
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
|
||||
Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
|
||||
A: 1085 tennis balls will fit in the trunk of a Honda Civic
|
||||
|
||||
Q: Share some random, interesting experiences from this month
|
||||
Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
|
||||
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
|
||||
|
||||
Q: Is Bob older than Tom?
|
||||
Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
|
||||
A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.
|
||||
|
||||
Q: What is their age difference?
|
||||
Khoj: {{"queries": ["What is Bob's age?", "What is Tom's age?"]}}
|
||||
A: Bob is {bob_tom_age_difference} years older than Tom. As Bob is {bob_age} years old and Tom is 30 years old.
|
||||
|
||||
Q: Who all did I meet here yesterday?
|
||||
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
|
||||
|
||||
Actual
|
||||
---
|
||||
{chat_history}
|
||||
Q: {text}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
|
||||
extract_questions_system_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
@@ -651,7 +590,7 @@ A: You had a great time at the local beach with your friends, attended a music c
|
||||
""".strip()
|
||||
)
|
||||
|
||||
extract_questions_anthropic_user_message = PromptTemplate.from_template(
|
||||
extract_questions_user_message = PromptTemplate.from_template(
|
||||
"""
|
||||
Here's our most recent chat history:
|
||||
{chat_history}
|
||||
|
||||
@@ -82,21 +82,17 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||
anthropic_send_message_to_model,
|
||||
converse_anthropic,
|
||||
extract_questions_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.google.gemini_chat import (
|
||||
converse_gemini,
|
||||
extract_questions_gemini,
|
||||
gemini_send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
extract_questions_offline,
|
||||
send_message_to_model_offline,
|
||||
)
|
||||
from khoj.processor.conversation.openai.gpt import (
|
||||
converse_openai,
|
||||
extract_questions,
|
||||
send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
@@ -107,6 +103,7 @@ from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
clean_mermaidjs,
|
||||
construct_chat_history,
|
||||
construct_question_history,
|
||||
defilter_query,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
@@ -1222,7 +1219,6 @@ async def search_documents(
|
||||
return
|
||||
|
||||
filters_in_query += " ".join([f'file:"{filter}"' for filter in conversation.file_filters])
|
||||
using_offline_chat = False
|
||||
if is_none_or_empty(filters_in_query):
|
||||
logger.debug(f"Filters in query: {filters_in_query}")
|
||||
|
||||
@@ -1230,89 +1226,18 @@ async def search_documents(
|
||||
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
vision_enabled = chat_model.vision_enabled
|
||||
inferred_queries = await extract_questions(
|
||||
query=defiltered_query,
|
||||
user=user,
|
||||
personality_context=personality_context,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
using_offline_chat = True
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
loaded_model=loaded_model,
|
||||
chat_history=chat_history,
|
||||
should_extract_questions=True,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
personality_context=personality_context,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions(
|
||||
defiltered_query,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=base_url,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_anthropic(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
chat_history=chat_history,
|
||||
location_data=location_data,
|
||||
max_tokens=chat_model.max_prompt_size,
|
||||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
# Collate search results as context for the LLM
|
||||
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
|
||||
with timer("Searching knowledge base took", logger):
|
||||
search_results = []
|
||||
@@ -1322,12 +1247,11 @@ async def search_documents(
|
||||
async for event in send_status_func(f"**Searching Documents for:** {inferred_queries_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if using_offline_chat else n
|
||||
search_results.extend(
|
||||
await execute_search(
|
||||
user if not should_limit_to_agent_knowledge else None,
|
||||
f"{query} {filters_in_query}",
|
||||
n=n_items,
|
||||
n=n,
|
||||
t=state.SearchType.All,
|
||||
r=True,
|
||||
max_distance=d,
|
||||
@@ -1344,6 +1268,78 @@ async def search_documents(
|
||||
yield compiled_references, inferred_queries, defiltered_query
|
||||
|
||||
|
||||
async def extract_questions(
|
||||
query: str,
|
||||
user: KhojUser,
|
||||
personality_context: str = "",
|
||||
chat_history: List[ChatMessageModel] = [],
|
||||
location_data: LocationData = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Infer document search queries from user message and provided context
|
||||
"""
|
||||
# Shared context setup
|
||||
location = f"{location_data}" if location_data else "N/A"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Date variables for prompt formatting
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# Common prompt setup for API-based models (using Anthropic prompts for consistency)
|
||||
chat_history_str = construct_question_history(chat_history, query_prefix="User", agent_name="Assistant")
|
||||
|
||||
system_prompt = prompts.extract_questions_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=yesterday,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_user_message.format(text=query, chat_history=chat_history_str)
|
||||
|
||||
class DocumentQueries(BaseModel):
|
||||
"""Choose searches to run on user documents."""
|
||||
|
||||
queries: List[str] = Field(..., min_items=1, description="List of search queries to run on user documents.")
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
system_message=system_prompt,
|
||||
query=prompt,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
chat_history=chat_history,
|
||||
response_type="json_object",
|
||||
response_schema=DocumentQueries,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract questions from the response
|
||||
try:
|
||||
response = clean_json(raw_response)
|
||||
response = pyjson5.loads(response)
|
||||
queries = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(queries, list) or not queries:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [query]
|
||||
return queries
|
||||
except:
|
||||
logger.warning(f"LLM returned invalid JSON. Falling back to using user message as search query.")
|
||||
return [query]
|
||||
|
||||
|
||||
async def execute_search(
|
||||
user: KhojUser,
|
||||
q: str,
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
import pytest
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
||||
|
||||
@@ -16,11 +16,7 @@ pytestmark = pytest.mark.skipif(
|
||||
import freezegun
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
extract_questions_offline,
|
||||
filter_questions,
|
||||
)
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.utils.constants import default_offline_chat_models
|
||||
|
||||
@@ -39,7 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
|
||||
response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model)
|
||||
|
||||
assert len(response) >= 1
|
||||
|
||||
@@ -59,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
@@ -81,7 +77,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_year():
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries have I visited this year?")
|
||||
response = extract_questions("Which countries have I visited this year?")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
@@ -99,7 +95,7 @@ def test_extract_question_with_date_filter_from_relative_year():
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
responses = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(responses) >= 2
|
||||
@@ -111,7 +107,7 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||
response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["height", "taller", "shorter", "heights", "who"]
|
||||
@@ -133,7 +129,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
query = "Does he have any sons?"
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
query,
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -179,7 +175,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
"Is she a Doctor?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -208,7 +204,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -609,15 +605,3 @@ def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||
assert prompt_size_exceeded_error not in response, (
|
||||
"Expected chat response to be within prompt limits, but got exceeded error: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_filter_questions():
|
||||
test_questions = [
|
||||
"I don't know how to answer that",
|
||||
"I cannot answer anything about the nuclear secrets",
|
||||
"Who is on the basketball team?",
|
||||
]
|
||||
filtered_questions = filter_questions(test_questions)
|
||||
assert len(filtered_questions) == 1
|
||||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
@@ -4,10 +4,11 @@ import freezegun
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.openai.gpt import converse_openai, extract_questions
|
||||
from khoj.processor.conversation.openai.gpt import converse_openai
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import (
|
||||
aget_data_sources_and_output_format,
|
||||
extract_questions,
|
||||
generate_online_subqueries,
|
||||
infer_webpage_urls,
|
||||
schedule_query,
|
||||
|
||||
Reference in New Issue
Block a user