Make streaming optional for the /chat endpoint (#287)

* Update the /chat endpoint to conditionally support streaming

- If streams are enabled, return the threadgenerator as it does currently
- If stream is disabled, return a JSON response with the response/compiled references separated out
- Correspondingly, update the chat.html UI to use the streamed API, as well as Obsidian
- Rename chat/init/ to chat/history

* Update khoj.el to use the /history endpoint

- Update corresponding unit tests to use stream=true

* Remove & from call to /chat for obsidian

* Abstract functions out into a helpers.py file and clean up some of the error-catching
This commit is contained in:
sabaimran
2023-07-09 10:12:09 -07:00
committed by GitHub
parent 0a86220d42
commit 4c135ea316
7 changed files with 177 additions and 100 deletions

View File

@@ -174,6 +174,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]

View File

@@ -40,7 +40,7 @@ def populate_chat_history(message_list):
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -62,7 +62,7 @@ def test_answer_from_chat_history(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -157,7 +157,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -175,7 +175,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
def test_answer_requires_current_date_awareness(chat_client):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Act
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -192,7 +192,7 @@ def test_answer_requires_current_date_awareness(chat_client):
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -212,7 +212,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
response = chat_client.get(
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
)
response_message = response.content.decode("utf-8")
# Assert
@@ -229,7 +231,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -258,7 +260,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
populate_chat_history(message_list)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -274,11 +276,11 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita"]
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message