diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index bb5d355a..213f20ad 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -64,14 +64,47 @@
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;
- // Call specified Khoj API
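+            // Create an empty Khoj chat message up front; streamed tokens will be
+            // rendered into it incrementally as they arrive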
+ let chat_body = document.getElementById("chat-body");
+ let new_response = document.createElement("div");
+ new_response.classList.add("chat-message", "khoj");
+            new_response.setAttribute("data-meta", "🏮 Khoj at " + formatDate(new Date()));
+ chat_body.appendChild(new_response);
+
+ let new_response_text = document.createElement("div");
+ new_response_text.classList.add("chat-message-text", "khoj");
+ new_response.appendChild(new_response_text);
+
+ // Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
- .then(response => response.json())
- .then(data => {
- // Render message by Khoj to chat body
- console.log(data.response);
- renderMessageWithReference(data.response, "khoj", data.context);
+ .then(response => {
+ console.log(response);
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+
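+                // Recursively read chunks off the response stream and append
+                // each decoded chunk to the chat message as it arrives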
+ function readStream() {
+ reader.read().then(({ done, value }) => {
+                    if (done) {
+                        // Flush any bytes the decoder buffered mid-character
+                        new_response_text.innerHTML += decoder.decode();
+                        console.log("Stream complete");
+                        return;
+                    }
+
+ const chunk = decoder.decode(value, { stream: true });
+ new_response_text.innerHTML += chunk;
+ console.log(`Received ${chunk.length} bytes of data`);
+ console.log(`Chunk: ${chunk}`);
+ readStream();
+ });
+ }
+ readStream();
});
}
function incrementalChat(event) {
@@ -82,7 +115,7 @@
}
window.onload = function () {
- fetch('/api/chat?client=web')
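+            // Fetch chat history from the new /api/chat/init endpoint,
+            // as /api/chat now streams chat responses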
+ fetch('/api/chat/init?client=web')
.then(response => response.json())
.then(data => {
if (data.detail) {
diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index 2f29b2ef..74f13305 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
# Get Response from GPT
logger.debug(f"Conversation Context for GPT: {messages}")
- response = chat_completion_with_backoff(
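+    # Return a generator that streams response tokens as the model produces them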
+ return chat_completion_with_backoff(
messages=messages,
model_name=model,
temperature=temperature,
openai_api_key=api_key,
)
- # Extract, Clean Message from GPT's Response
- return response.strip(empty_escape_sequences)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 03b6d9b1..0730f3f2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -2,11 +2,19 @@
import os
import logging
from datetime import datetime
+from typing import Any
+from threading import Thread
+import queue
# External Packages
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import ChatMessage
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager
import openai
import tiktoken
from tenacity import (
@@ -20,12 +28,43 @@ from tenacity import (
# Internal Packages
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
+class ThreadedGenerator:
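+    """Thread-safe iterator that bridges tokens produced on a worker thread to a
+    consumer iterating on another thread (e.g. a StreamingResponse)."""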
+ def __init__(self):
+ self.queue = queue.Queue()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ item = self.queue.get()
+ if item is StopIteration:
+ raise item
+ return item
+
+ def send(self, data):
+ self.queue.put(data)
+
+ def close(self):
+ self.queue.put(StopIteration)
+
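+# A minimal usage sketch of ThreadedGenerator (illustrative only):
+#
+#   g = ThreadedGenerator()
+#   Thread(target=lambda: (g.send("Hello "), g.send("world"), g.close())).start()
+#   "".join(g)  # -> "Hello world"
+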
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
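+    """LangChain callback handler that forwards each newly generated token to a ThreadedGenerator."""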
+ def __init__(self, gen: ThreadedGenerator):
+ super().__init__()
+ self.gen = gen
+
+ def on_llm_new_token(self, token: str, **kwargs) -> Any:
+ logger.debug(f"New Token: {token}")
+ self.gen.send(token)
+
+
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
@@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
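+    # Run the LLM call on a worker thread and return the generator immediately,
+    # so the caller can stream tokens as they are produced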
+ g = ThreadedGenerator()
+ t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+ t.start()
+ return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+ callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
+ streaming=True,
+ verbose=True,
+ callback_manager=BaseCallbackManager([callback_handler]),
model_name=model_name,
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20,
max_retries=1,
)
- return chat(messages).content
+
+    try:
+        chat(messages=messages)
+    finally:
+        # Close the generator even if the LLM call fails, so consumers iterating
+        # it terminate instead of blocking on the queue forever
+        g.close()
def generate_chatml_messages_with_context(
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index a32f420b..2d77d2ef 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
+from fastapi.responses import StreamingResponse
# Initialize Router
api = APIRouter()
@@ -393,8 +394,8 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}
-@api.get("/chat")
-async def chat(
+@api.get("/chat/init")
+def chat_init(
request: Request,
q: Optional[str] = None,
client: Optional[str] = None,
@@ -411,13 +412,52 @@ async def chat(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
+ # Load Conversation History
+ meta_log = state.processor_config.conversation.meta_log
+
+ user_state = {
+ "client_host": request.client.host,
+ "user_agent": user_agent or "unknown",
+ "referer": referer or "unknown",
+ "host": host or "unknown",
+ }
+
+ state.telemetry += [
+ log_telemetry(
+ telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
+ )
+ ]
+
+    # Return current chat history to initialize the web chat view
+    return {"status": "ok", "response": meta_log.get("chat", [])}
+
+
+@api.get("/chat", response_class=StreamingResponse)
+async def chat(
+ request: Request,
+ q: Optional[str] = None,
+ client: Optional[str] = None,
+ user_agent: Optional[str] = Header(None),
+ referer: Optional[str] = Header(None),
+ host: Optional[str] = Header(None),
+) -> StreamingResponse:
+ if (
+ state.processor_config is None
+ or state.processor_config.conversation is None
+ or state.processor_config.conversation.openai_api_key is None
+ ):
+ raise HTTPException(
+ status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
+ )
+
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
-    # If user query is empty, return chat history
+    # If query is empty, return an empty stream; chat history is served by /api/chat/init
     if not q:
-        return {"status": "ok", "response": meta_log.get("chat", [])}
+        return StreamingResponse(iter([]), media_type="text/plain")
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
@@ -446,24 +486,6 @@ async def chat(
conversation_type = "notes" if compiled_references else "general"
logger.debug(f"Conversation Type: {conversation_type}")
- try:
- with timer("Generating chat response took", logger):
- gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
- status = "ok"
- except Exception as e:
- gpt_response = str(e)
- status = "error"
-
- # Update Conversation History
- state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
- state.processor_config.conversation.meta_log["chat"] = message_to_log(
- q,
- gpt_response,
- user_message_metadata={"created": user_message_time},
- khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
- conversation_log=meta_log.get("chat", []),
- )
-
user_state = {
"client_host": request.client.host,
"user_agent": user_agent or "unknown",
@@ -477,4 +499,20 @@ async def chat(
)
]
- return {"status": status, "response": gpt_response, "context": compiled_references}
+    try:
+        # converse() returns a token generator immediately; generation happens
+        # on a worker thread while the response streams
+        with timer("Setting up chat response took", logger):
+            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+    except Exception as e:
+        # Stream the error message as a single chunk, not character by character
+        gpt_response = iter([str(e)])
+
+    # TODO: Re-enable updating conversation history once the streamed response can
+    # be aggregated; gpt_response is now a generator, so the full message text is
+    # not available until the client finishes consuming the stream
+    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
+    #     q,
+    #     gpt_response,
+    #     user_message_metadata={"created": user_message_time},
+    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+    #     conversation_log=meta_log.get("chat", []),
+    # )
+
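+    # Stream tokens to the client as they are generated; StreamingResponse
+    # consumes the ThreadedGenerator returned by converse()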
+    return StreamingResponse(gpt_response, media_type="text/plain")