diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 213f20ad..7d93c836 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -89,10 +89,26 @@
}
const chunk = decoder.decode(value, { stream: true });
- new_response_text.innerHTML += chunk;
- console.log(`Received ${chunk.length} bytes of data`);
- console.log(`Chunk: ${chunk}`);
- readStream();
+
+                    // The server appends compiled references as one final chunk,
+                    // prefixed with a sentinel and encoded as JSON.
+                    if (chunk.startsWith("### compiled references:")) {
+                        const rawReference = chunk.split("### compiled references:")[1];
+                        const rawReferenceAsJson = JSON.parse(rawReference);
+                        const polishedReference = rawReferenceAsJson
+                            .map((reference, index) => generateReference(reference, index))
+                            .join(",");
+
+                        new_response_text.innerHTML += polishedReference;
+                    } else {
+                        new_response_text.innerHTML += chunk;
+                        console.log(`Received ${chunk.length} bytes of data`);
+                        console.log(`Chunk: ${chunk}`);
+                    }
+
+                    // Keep the latest output in view and keep draining the stream,
+                    // so the done handler above still runs after the references arrive.
+                    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                    readStream();
});
}
readStream();
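
Any streaming client can mirror the browser logic above: accumulate the chunked response, split once on the sentinel, and parse the remainder as JSON. A minimal Python sketch of that client side; the `requests` dependency, the default port, and the exact response shape of `/api/chat` are assumptions, not part of this change:

```python
import json

import requests  # third-party HTTP client, assumed for this sketch

SENTINEL = "### compiled references:"

def stream_chat(query: str, host: str = "http://localhost:42110"):
    """Collect the streamed chat response, then split out the trailing
    sentinel-prefixed references payload the server appends before closing."""
    buffer = ""
    with requests.get(f"{host}/api/chat", params={"q": query}, stream=True) as response:
        response.raise_for_status()
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            buffer += chunk
    text, _, raw_references = buffer.partition(SENTINEL)
    references = json.loads(raw_references) if raw_references else []
    return text, references
```
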
diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index 74f13305..1c993c37 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -172,6 +172,7 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
logger.debug(f"Conversation Context for GPT: {messages}")
return chat_completion_with_backoff(
messages=messages,
+ compiled_references=references,
model_name=model,
temperature=temperature,
openai_api_key=api_key,
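
Since `converse` returns the `ThreadedGenerator` built by `chat_completion_with_backoff`, a caller can iterate it directly: answer tokens arrive first, and the compiled references arrive as one final sentinel-prefixed chunk. A hypothetical consumption sketch, assuming khoj is installed and a real OpenAI key is supplied; the query and reference text are fabricated:

```python
from khoj.processor.conversation.gpt import converse

# Fabricated inputs; "sk-..." stands in for a real OpenAI API key.
gpt_response = converse(
    ["compiled note about streams"],    # references
    "What did I write about streams?",  # user_query
    api_key="sk-...",
)
for chunk in gpt_response:
    # Tokens stream first; the last chunk carries the references payload.
    print(chunk, end="", flush=True)
```
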
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 0730f3f2..f2a87b3c 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -6,6 +6,7 @@ from typing import Any, Optional
from uuid import UUID
import asyncio
from threading import Thread
+import json
# External Packages
from langchain.chat_models import ChatOpenAI
@@ -36,8 +37,9 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
class ThreadedGenerator:
- def __init__(self):
+ def __init__(self, compiled_references):
self.queue = queue.Queue()
+ self.compiled_references = compiled_references
def __iter__(self):
return self
@@ -45,13 +47,16 @@
def __next__(self):
item = self.queue.get()
if item is StopIteration:
- raise item
+ raise StopIteration
return item
def send(self, data):
self.queue.put(data)
def close(self):
+        # Emit compiled references as one final sentinel-prefixed JSON chunk
+        if self.compiled_references:
+            self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
self.queue.put(StopIteration)
@@ -101,8 +106,8 @@ def completion_with_backoff(**kwargs):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
-def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
- g = ThreadedGenerator()
+def chat_completion_with_backoff(messages, compiled_references, model_name, temperature, openai_api_key=None):
+ g = ThreadedGenerator(compiled_references)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
t.start()
return g
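
The `ThreadedGenerator` above is the core of this change: a queue-backed iterator that a producer thread (the `llm_thread` target passed to `Thread` above) feeds token by token, with the compiled references appended as one last sentinel-prefixed JSON chunk before the `StopIteration` marker. A self-contained sketch of the pattern, with a fake token producer standing in for `llm_thread`:

```python
import json
import queue
from threading import Thread


class ThreadedGenerator:
    """Queue-backed generator: a producer thread feeds tokens via send();
    close() appends the compiled references, then an end-of-stream sentinel."""

    def __init__(self, compiled_references):
        self.queue = queue.Queue()
        self.compiled_references = compiled_references

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()  # blocks until the producer sends data
        if item is StopIteration:
            raise StopIteration
        return item

    def send(self, data):
        self.queue.put(data)

    def close(self):
        if self.compiled_references:
            self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
        self.queue.put(StopIteration)


def produce_tokens(g: ThreadedGenerator):
    # Stand-in for llm_thread: stream a few fabricated tokens, then close.
    for token in ["Hello", ", ", "world", "."]:
        g.send(token)
    g.close()


g = ThreadedGenerator(compiled_references=["note one", "note two"])
Thread(target=produce_tokens, args=(g,)).start()
for chunk in g:
    print(chunk, end="")
```
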
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 2d77d2ef..3acbd62a 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -478,7 +478,7 @@ async def chat(
result_list = []
for query in inferred_queries:
result_list.extend(
- await search(query, request=request, n=5, r=True, score_threshold=-5.0, dedupe=False)
+ await search(query, request=request, n=5, r=False, score_threshold=-5.0, dedupe=False)
)
compiled_references = [item.additional["compiled"] for item in result_list]
@@ -501,7 +501,15 @@ async def chat(
try:
with timer("Generating chat response took", logger):
- gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+ gpt_response = converse(
+ compiled_references,
+ q,
+ meta_log,
+ model=chat_model,
+ api_key=api_key,
+ chat_session=chat_session,
+ inferred_queries=inferred_queries,
+ )
except Exception as e:
gpt_response = str(e)
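
The route handler above hands `gpt_response` back to whatever response class the router uses; that wiring sits outside this hunk. One way such a generator could reach the browser as chunked text, sketched with FastAPI's `StreamingResponse` (khoj's router is FastAPI-based, but the exact path, parameter handling, and media type here are assumptions):

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/api/chat")
def chat(q: str):
    # Stub generator standing in for the ThreadedGenerator that converse() returns.
    def gpt_response():
        yield "Streamed answer text..."
        yield '### compiled references:["fabricated note"]'
    return StreamingResponse(gpt_response(), media_type="text/plain")
```
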