Use single extract questions method across all LLMs for doc search

Using model-specific extract-questions methods was an artifact from older
times, when models were less guidable.

New changes collate and reuse logic
- Rely on send_message_to_model_wrapper for model specific formatting.
- Use the same prompt and context for all LLMs, as they can handle prompt variation.
- Use response schema enforcer to ensure response consistency across models.

Extract questions (because of its age) was the only tool implemented directly
within each provider's code. Move it into helpers so that all the (mini) tools
live in one place.
This commit is contained in:
Debanjum
2025-06-05 02:15:58 -07:00
parent c2cd92a454
commit 2f4160e24b
8 changed files with 109 additions and 575 deletions

View File

@@ -3,7 +3,7 @@ from datetime import datetime
import pytest
from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
@@ -16,11 +16,7 @@ pytestmark = pytest.mark.skipif(
import freezegun
from freezegun import freeze_time
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
filter_questions,
)
from khoj.processor.conversation.offline.chat_model import converse_offline
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.constants import default_offline_chat_models
@@ -39,7 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# Act
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model)
assert len(response) >= 1
@@ -59,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# Act
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model)
# Assert
assert len(response) >= 1
@@ -81,7 +77,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
# Act
response = extract_questions_offline("Which countries have I visited this year?")
response = extract_questions("Which countries have I visited this year?")
# Assert
expected_responses = [
@@ -99,7 +95,7 @@ def test_extract_question_with_date_filter_from_relative_year():
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
# Act
responses = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model)
# Assert
assert len(responses) >= 2
@@ -111,7 +107,7 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model)
# Assert
expected_responses = ["height", "taller", "shorter", "heights", "who"]
@@ -133,7 +129,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
query = "Does he have any sons?"
# Act
response = extract_questions_offline(
response = extract_questions(
query,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -179,7 +175,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
]
# Act
response = extract_questions_offline(
response = extract_questions(
"Is she a Doctor?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -208,7 +204,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
]
# Act
response = extract_questions_offline(
response = extract_questions(
"What was the Pizza place we ate at over there?",
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
@@ -609,15 +605,3 @@ def test_chat_does_not_exceed_prompt_size(loaded_model):
assert prompt_size_exceeded_error not in response, (
"Expected chat response to be within prompt limits, but got exceeded error: " + response
)
# ----------------------------------------------------------------------------------------------------
def test_filter_questions():
test_questions = [
"I don't know how to answer that",
"I cannot answer anything about the nuclear secrets",
"Who is on the basketball team?",
]
filtered_questions = filter_questions(test_questions)
assert len(filtered_questions) == 1
assert filtered_questions[0] == "Who is on the basketball team?"