From afd162de01b05e021a872a7b0357dd5dee79e668 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 4 Jul 2023 12:47:50 -0700 Subject: [PATCH] Add reference notes to result response from GPT when streaming is completed - NOTE: results are still not being saved to conversation history --- src/khoj/interface/web/chat.html | 19 +++++++++++++++---- src/khoj/processor/conversation/gpt.py | 1 + src/khoj/processor/conversation/utils.py | 12 ++++++++---- src/khoj/routers/api.py | 12 ++++++++++-- 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 213f20ad..7d93c836 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -89,10 +89,21 @@ } const chunk = decoder.decode(value, { stream: true }); - new_response_text.innerHTML += chunk; - console.log(`Received ${chunk.length} bytes of data`); - console.log(`Chunk: ${chunk}`); - readStream(); + + if (chunk.startsWith("### compiled references:")) { + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + let polishedReference = rawReferenceAsJson.map((reference, index) => generateReference(reference, index)) + .join(","); + + new_response_text.innerHTML += polishedReference; + } else { + new_response_text.innerHTML += chunk; + console.log(`Received ${chunk.length} bytes of data`); + console.log(`Chunk: ${chunk}`); + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + readStream(); + } }); } readStream(); diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 74f13305..1c993c37 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -172,6 +172,7 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", logger.debug(f"Conversation Context for GPT: {messages}") return chat_completion_with_backoff( messages=messages, + compiled_references=references, model_name=model, temperature=temperature, openai_api_key=api_key, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 0730f3f2..f2a87b3c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -6,6 +6,7 @@ from typing import Any, Optional from uuid import UUID import asyncio from threading import Thread +import json # External Packages from langchain.chat_models import ChatOpenAI @@ -36,8 +37,9 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192} class ThreadedGenerator: - def __init__(self): + def __init__(self, compiled_references): self.queue = queue.Queue() + self.compiled_references = compiled_references def __iter__(self): return self @@ -45,13 +47,15 @@ class ThreadedGenerator: def __next__(self): item = self.queue.get() if item is StopIteration: - raise item + raise StopIteration return item def send(self, data): self.queue.put(data) def close(self): + if self.compiled_references and len(self.compiled_references) > 0: + self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}") self.queue.put(StopIteration) @@ -101,8 +105,8 @@ def completion_with_backoff(**kwargs): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None): - g = ThreadedGenerator() +def chat_completion_with_backoff(messages, compiled_references, model_name, temperature, openai_api_key=None): + g = ThreadedGenerator(compiled_references) t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) t.start() return g diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 2d77d2ef..3acbd62a 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -478,7 +478,7 @@ async def chat( result_list = [] for query in inferred_queries: result_list.extend( - await search(query, request=request, n=5, r=True, score_threshold=-5.0, dedupe=False) + await search(query, request=request, n=5, r=False, score_threshold=-5.0, dedupe=False) ) compiled_references = [item.additional["compiled"] for item in result_list] @@ -501,7 +501,15 @@ async def chat( try: with timer("Generating chat response took", logger): - gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key) + gpt_response = converse( + compiled_references, + q, + meta_log, + model=chat_model, + api_key=api_key, + chat_session=chat_session, + inferred_queries=inferred_queries, + ) except Exception as e: gpt_response = str(e)