mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Use single extract questions method across all LLMs for doc search
Using model specific extract questions was an artifact from older times, with less guidable models. New changes collate and reuse logic - Rely on send_message_to_model_wrapper for model specific formatting. - Use same prompt, context for all LLMs as can handle prompt variation. - Use response schema enforcer to ensure response consistency across models. Extract questions (because of its age) was the only tool directly within each provider code. Put it into helpers to have all the (mini) tools in one place.
This commit is contained in:
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
import pytest
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
||||
|
||||
@@ -16,11 +16,7 @@ pytestmark = pytest.mark.skipif(
|
||||
import freezegun
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
extract_questions_offline,
|
||||
filter_questions,
|
||||
)
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.utils.constants import default_offline_chat_models
|
||||
|
||||
@@ -39,7 +35,7 @@ freezegun.configure(extend_ignore_list=["transformers"])
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
|
||||
response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model)
|
||||
|
||||
assert len(response) >= 1
|
||||
|
||||
@@ -59,7 +55,7 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
@@ -81,7 +77,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_year():
|
||||
# Act
|
||||
response = extract_questions_offline("Which countries have I visited this year?")
|
||||
response = extract_questions("Which countries have I visited this year?")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
@@ -99,7 +95,7 @@ def test_extract_question_with_date_filter_from_relative_year():
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
responses = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(responses) >= 2
|
||||
@@ -111,7 +107,7 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||
response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["height", "taller", "shorter", "heights", "who"]
|
||||
@@ -133,7 +129,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
query = "Does he have any sons?"
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
query,
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -179,7 +175,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
"Is she a Doctor?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -208,7 +204,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
response = extract_questions(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
@@ -609,15 +605,3 @@ def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||
assert prompt_size_exceeded_error not in response, (
|
||||
"Expected chat response to be within prompt limits, but got exceeded error: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_filter_questions():
|
||||
test_questions = [
|
||||
"I don't know how to answer that",
|
||||
"I cannot answer anything about the nuclear secrets",
|
||||
"Who is on the basketball team?",
|
||||
]
|
||||
filtered_questions = filter_questions(test_questions)
|
||||
assert len(filtered_questions) == 1
|
||||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
Reference in New Issue
Block a user