diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index bb5d355a..213f20ad 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -64,14 +64,47 @@
         // Generate backend API URL to execute query
         let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;

-        // Call specified Khoj API
+        let chat_body = document.getElementById("chat-body");
+        let new_response = document.createElement("div");
+        new_response.classList.add("chat-message", "khoj");
+        new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
+        chat_body.appendChild(new_response);
+
+        let new_response_text = document.createElement("div");
+        new_response_text.classList.add("chat-message-text", "khoj");
+        new_response.appendChild(new_response_text);
+
+        // Call specified Khoj API which returns a streamed response of type text/plain
        fetch(url)
-            .then(response => response.json())
-            .then(data => {
-                // Render message by Khoj to chat body
-                console.log(data.response);
-                renderMessageWithReference(data.response, "khoj", data.context);
+            .then(response => {
+                console.log(response);
+                const reader = response.body.getReader();
+                const decoder = new TextDecoder();
+
+                function readStream() {
+                    reader.read().then(({ done, value }) => {
+                        if (done) {
+                            console.log("Stream complete");
+                            return;
+                        }
+
+                        const chunk = decoder.decode(value, { stream: true });
+                        new_response_text.innerHTML += chunk;
+                        console.log(`Received ${chunk.length} bytes of data`);
+                        console.log(`Chunk: ${chunk}`);
+                        readStream();
+                    });
+                }
+                readStream();
            });
+
+
+        // fetch(url)
+        //     .then(data => {
+        //         // Render message by Khoj to chat body
+        //         console.log(data.response);
+        //         renderMessageWithReference(data.response, "khoj", data.context);
+        //     });
    }

    function incrementalChat(event) {
@@ -82,7 +115,7 @@
    }

    window.onload = function () {
-        fetch('/api/chat?client=web')
+        fetch('/api/chat/init?client=web')
            .then(response => response.json())
            .then(data => {
                if (data.detail) {
diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index 2f29b2ef..74f13305 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
    # Get Response from GPT
    logger.debug(f"Conversation Context for GPT: {messages}")
-    response = chat_completion_with_backoff(
+    return chat_completion_with_backoff(
        messages=messages,
        model_name=model,
        temperature=temperature,
        openai_api_key=api_key,
    )
-
-    # Extract, Clean Message from GPT's Response
-    return response.strip(empty_escape_sequences)
+
+    # async for tokens in chat_completion_with_backoff(
+    #     messages=messages,
+    #     model_name=model,
+    #     temperature=temperature,
+    #     openai_api_key=api_key,
+    # ):
+    #     logger.info(f"Tokens from GPT: {tokens}")
+    #     yield tokens
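
Note on the gpt.py change above: converse now returns the iterator produced by chat_completion_with_backoff (a ThreadedGenerator, defined in utils.py below) instead of a final cleaned string, so callers consume tokens as they arrive. A minimal sketch of that consumption pattern; fake_converse is a hypothetical stand-in for the real call, not part of the Khoj codebase:

    # Sketch only: fake_converse mimics a generator that yields tokens lazily,
    # the way converse() does after this change.
    def fake_converse():
        yield from ["The", " answer", " is", " 42", "."]

    # The caller streams tokens as they arrive instead of waiting for the
    # complete response string.
    for token in fake_converse():
        print(token, end="", flush=True)
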
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 03b6d9b1..0730f3f2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -2,11 +2,19 @@
 import os
 import logging
 from datetime import datetime
+from typing import Any, Optional
+from uuid import UUID
+import asyncio
+from threading import Thread

 # External Packages
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
 import openai
 import tiktoken
 from tenacity import (
@@ -20,12 +28,43 @@ from tenacity import (
 # Internal Packages
 from khoj.utils.helpers import merge_dicts
+import queue

 logger = logging.getLogger(__name__)

 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}


+class ThreadedGenerator:
+    def __init__(self):
+        self.queue = queue.Queue()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = self.queue.get()
+        if item is StopIteration:
+            raise item
+        return item
+
+    def send(self, data):
+        self.queue.put(data)
+
+    def close(self):
+        self.queue.put(StopIteration)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+    def __init__(self, gen: ThreadedGenerator):
+        super().__init__()
+        self.gen = gen
+
+    def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        logger.debug(f"New Token: {token}")
+        self.gen.send(token)
+
+
 @retry(
    retry=(
        retry_if_exception_type(openai.error.Timeout)
@@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
    reraise=True,
 )
 def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
+    g = ThreadedGenerator()
+    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+    t.start()
+    return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+    callback_handler = StreamingChatCallbackHandler(g)
    chat = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callback_manager=BaseCallbackManager([callback_handler]),
        model_name=model_name,
        temperature=temperature,
        openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
        request_timeout=20,
        max_retries=1,
    )
-    return chat(messages).content
+
+    chat(messages=messages)
+
+    g.close()


 def generate_chatml_messages_with_context(
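
The ThreadedGenerator above bridges two threads: LangChain's on_llm_new_token callback fires on the worker thread and pushes tokens onto a queue, while the HTTP handler iterates the generator on its own thread until a StopIteration sentinel arrives. A self-contained sketch of the same producer/consumer pattern, with a toy producer (TokenBridge, producer) standing in for the real LLM thread:

    import queue
    from threading import Thread

    class TokenBridge:
        """Toy version of ThreadedGenerator: a blocking-queue-backed iterator."""
        def __init__(self):
            self.queue = queue.Queue()

        def __iter__(self):
            return self

        def __next__(self):
            item = self.queue.get()  # blocks until the producer sends something
            if item is StopIteration:
                raise item
            return item

        def send(self, data):
            self.queue.put(data)

        def close(self):
            self.queue.put(StopIteration)  # sentinel ends iteration cleanly

    def producer(bridge):
        # Stand-in for llm_thread(): push tokens, then signal completion
        for token in ["tokens ", "cross ", "the ", "thread ", "boundary"]:
            bridge.send(token)
        bridge.close()

    bridge = TokenBridge()
    Thread(target=producer, args=(bridge,)).start()
    print("".join(bridge))  # consumer drains tokens as they arrive
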
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index a32f420b..2d77d2ef 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
 from khoj.utils.yaml import save_config_to_file_updated_state
+from fastapi.responses import StreamingResponse

 # Initialize Router
 api = APIRouter()
@@ -393,8 +394,8 @@ def update(
    return {"status": "ok", "message": "khoj reloaded"}


-@api.get("/chat")
-async def chat(
+@api.get("/chat/init")
+def chat_init(
    request: Request,
    q: Optional[str] = None,
    client: Optional[str] = None,
@@ -411,13 +412,52 @@ async def chat(
            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
        )

+    # Load Conversation History
+    meta_log = state.processor_config.conversation.meta_log
+
+    user_state = {
+        "client_host": request.client.host,
+        "user_agent": user_agent or "unknown",
+        "referer": referer or "unknown",
+        "host": host or "unknown",
+    }
+
+    state.telemetry += [
+        log_telemetry(
+            telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
+        )
+    ]
+
+    # If user query is empty, return chat history
+    if not q:
+        return {"status": "ok", "response": meta_log.get("chat", [])}
+
+
+@api.get("/chat", response_class=StreamingResponse)
+async def chat(
+    request: Request,
+    q: Optional[str] = None,
+    client: Optional[str] = None,
+    user_agent: Optional[str] = Header(None),
+    referer: Optional[str] = Header(None),
+    host: Optional[str] = Header(None),
+) -> StreamingResponse:
+    if (
+        state.processor_config is None
+        or state.processor_config.conversation is None
+        or state.processor_config.conversation.openai_api_key is None
+    ):
+        raise HTTPException(
+            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
+        )
+
    # Load Conversation History
    chat_session = state.processor_config.conversation.chat_session
    meta_log = state.processor_config.conversation.meta_log

    # If user query is empty, return chat history
    if not q:
-        return {"status": "ok", "response": meta_log.get("chat", [])}
+        return StreamingResponse(None)

    # Initialize Variables
    api_key = state.processor_config.conversation.openai_api_key
@@ -446,24 +486,6 @@ async def chat(
    conversation_type = "notes" if compiled_references else "general"
    logger.debug(f"Conversation Type: {conversation_type}")

-    try:
-        with timer("Generating chat response took", logger):
-            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
-            status = "ok"
-    except Exception as e:
-        gpt_response = str(e)
-        status = "error"
-
-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q,
-        gpt_response,
-        user_message_metadata={"created": user_message_time},
-        khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
-        conversation_log=meta_log.get("chat", []),
-    )
-
    user_state = {
        "client_host": request.client.host,
        "user_agent": user_agent or "unknown",
@@ -477,4 +499,20 @@ async def chat(
        )
    ]

-    return {"status": status, "response": gpt_response, "context": compiled_references}
+    try:
+        with timer("Generating chat response took", logger):
+            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+    except Exception as e:
+        gpt_response = str(e)
+
+    # Update Conversation History
+    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
+    #     q,
+    #     gpt_response,
+    #     user_message_metadata={"created": user_message_time},
+    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+    #     conversation_log=meta_log.get("chat", []),
+    # )
+
+    return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
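
For reference, the server-side pattern the new /chat endpoint relies on: FastAPI's StreamingResponse accepts any iterable and writes it to the client chunk by chunk, which is what lets the ThreadedGenerator returned by converse flow straight to the browser's fetch reader. A minimal, self-contained sketch; the /demo/stream route and token_source() are illustrative stand-ins, not Khoj code:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    def token_source():
        # Stand-in for the ThreadedGenerator returned by converse()
        yield from ["each ", "chunk ", "is ", "flushed ", "to ", "the ", "client"]

    @app.get("/demo/stream")
    def stream():
        # FastAPI iterates the generator and streams each yielded chunk
        return StreamingResponse(token_source(), media_type="text/event-stream")
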