Initial code with chat streaming working (warning: messy code)

sabaimran
2023-07-04 10:14:39 -07:00
parent 89354def9b
commit 8f491d72de
4 changed files with 163 additions and 33 deletions

File 1 of 4: web chat interface script

@@ -64,14 +64,47 @@
    // Generate backend API URL to execute query
    let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;

-   // Call specified Khoj API
+   let chat_body = document.getElementById("chat-body");
+   let new_response = document.createElement("div");
+   new_response.classList.add("chat-message", "khoj");
+   new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
+   chat_body.appendChild(new_response);
+
+   let new_response_text = document.createElement("div");
+   new_response_text.classList.add("chat-message-text", "khoj");
+   new_response.appendChild(new_response_text);
+
+   // Call specified Khoj API which returns a streamed response of type text/plain
    fetch(url)
-       .then(response => response.json())
-       .then(data => {
-           // Render message by Khoj to chat body
-           console.log(data.response);
-           renderMessageWithReference(data.response, "khoj", data.context);
+       .then(response => {
+           console.log(response);
+           const reader = response.body.getReader();
+           const decoder = new TextDecoder();
+
+           function readStream() {
+               reader.read().then(({ done, value }) => {
+                   if (done) {
+                       console.log("Stream complete");
+                       return;
+                   }
+                   const chunk = decoder.decode(value, { stream: true });
+                   new_response_text.innerHTML += chunk;
+                   console.log(`Received ${chunk.length} bytes of data`);
+                   console.log(`Chunk: ${chunk}`);
+                   readStream();
+               });
+           }
+           readStream();
        });
+   // fetch(url)
+   //     .then(data => {
+   //         // Render message by Khoj to chat body
+   //         console.log(data.response);
+   //         renderMessageWithReference(data.response, "khoj", data.context);
+   //     });
 }

 function incrementalChat(event) {
@@ -82,7 +115,7 @@
 }

 window.onload = function () {
-    fetch('/api/chat?client=web')
+    fetch('/api/chat/init?client=web')
         .then(response => response.json())
         .then(data => {
             if (data.detail) {
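
For reference, the same chunked read the new web client performs with response.body.getReader() can be reproduced outside the browser. A minimal Python sketch, assuming a Khoj server reachable at http://localhost:42110 (an assumed local address; adjust host and port for your setup):

# Stream the /api/chat response chunk by chunk, mirroring the web client's readStream() loop.
import requests

BASE_URL = "http://localhost:42110"  # assumption: local Khoj server address

with requests.get(f"{BASE_URL}/api/chat", params={"q": "hello", "client": "web"}, stream=True) as response:
    response.raise_for_status()
    # iter_content(chunk_size=None) yields bytes as they arrive, like reader.read() in the browser
    for chunk in response.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"), end="", flush=True)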

File 2 of 4: GPT conversation module

@@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
     # Get Response from GPT
     logger.debug(f"Conversation Context for GPT: {messages}")
-    response = chat_completion_with_backoff(
+    return chat_completion_with_backoff(
         messages=messages,
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
     )
-
-    # Extract, Clean Message from GPT's Response
-    return response.strip(empty_escape_sequences)
+    # async for tokens in chat_completion_with_backoff(
+    #     messages=messages,
+    #     model_name=model,
+    #     temperature=temperature,
+    #     openai_api_key=api_key,
+    # ):
+    #     logger.info(f"Tokens from GPT: {tokens}")
+    #     yield tokens
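
With this change converse() no longer returns a string cleaned with strip(empty_escape_sequences); it hands back the generator from chat_completion_with_backoff(), so callers iterate over tokens as they arrive. A minimal sketch of the new calling convention (illustrative only, not part of this commit):

# compiled_references, q, meta_log as prepared in the /api/chat endpoint (illustrative)
response_stream = converse(compiled_references, q, meta_log, model="gpt-3.5-turbo", api_key=api_key)
gpt_response = ""
for token in response_stream:
    gpt_response += token  # accumulate only if the full message is still needed, e.g. for logging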

File 3 of 4: chat model utilities

@@ -2,11 +2,19 @@
 import os
 import logging
 from datetime import datetime
+from typing import Any, Optional
+from uuid import UUID
+import asyncio
+from threading import Thread

 # External Packages
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
 import openai
 import tiktoken
 from tenacity import (
@@ -20,12 +28,43 @@ from tenacity import (
 # Internal Packages
 from khoj.utils.helpers import merge_dicts
+import queue

 logger = logging.getLogger(__name__)

 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}


+class ThreadedGenerator:
+    def __init__(self):
+        self.queue = queue.Queue()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = self.queue.get()
+        if item is StopIteration:
+            raise item
+        return item
+
+    def send(self, data):
+        self.queue.put(data)
+
+    def close(self):
+        self.queue.put(StopIteration)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+    def __init__(self, gen: ThreadedGenerator):
+        super().__init__()
+        self.gen = gen
+
+    def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        logger.debug(f"New Token: {token}")
+        self.gen.send(token)
+
+
 @retry(
     retry=(
         retry_if_exception_type(openai.error.Timeout)
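
ThreadedGenerator bridges LangChain's synchronous streaming callback, which fires on a worker thread, to a plain iterator that the serving thread can consume: send() enqueues each token and close() enqueues StopIteration so iteration ends cleanly. A toy sketch of that handoff, independent of LangChain, assuming the ThreadedGenerator class above is in scope:

# Assumes the ThreadedGenerator defined in this commit is importable/in scope.
from threading import Thread

def produce(gen: ThreadedGenerator):
    # Worker thread: push tokens into the queue, then signal completion
    for token in ["Hello", ", ", "world", "!"]:
        gen.send(token)
    gen.close()

gen = ThreadedGenerator()
Thread(target=produce, args=(gen,)).start()
for token in gen:  # blocks on queue.get() until the next token arrives
    print(token, end="", flush=True)
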
@@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
     reraise=True,
 )
 def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
+    g = ThreadedGenerator()
+    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+    t.start()
+    return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+    callback_handler = StreamingChatCallbackHandler(g)
     chat = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callback_manager=BaseCallbackManager([callback_handler]),
         model_name=model_name,
         temperature=temperature,
         openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
         request_timeout=20,
         max_retries=1,
     )
-    return chat(messages).content
+    chat(messages=messages)
+    g.close()


 def generate_chatml_messages_with_context(
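
chat_completion_with_backoff() now returns a ThreadedGenerator immediately while llm_thread() drives the blocking ChatOpenAI call on a worker thread, with StreamingChatCallbackHandler forwarding each token into the generator's queue. A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment:

# Assumes OPENAI_API_KEY is set in the environment.
from langchain.schema import ChatMessage

messages = [ChatMessage(role="user", content="What is Khoj?")]
# Returns at once; tokens arrive as the worker thread produces them
for token in chat_completion_with_backoff(messages, model_name="gpt-3.5-turbo", temperature=0.2):
    print(token, end="", flush=True)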

File 4 of 4: API router

@@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
 from khoj.utils.yaml import save_config_to_file_updated_state
+from fastapi.responses import StreamingResponse

 # Initialize Router
 api = APIRouter()
@@ -393,8 +394,8 @@ def update(
return {"status": "ok", "message": "khoj reloaded"} return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat") @api.get("/chat/init")
async def chat( def chat_init(
request: Request, request: Request,
q: Optional[str] = None, q: Optional[str] = None,
client: Optional[str] = None, client: Optional[str] = None,
@@ -411,13 +412,52 @@ async def chat(
             status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
         )

+    # Load Conversation History
+    meta_log = state.processor_config.conversation.meta_log
+
+    user_state = {
+        "client_host": request.client.host,
+        "user_agent": user_agent or "unknown",
+        "referer": referer or "unknown",
+        "host": host or "unknown",
+    }
+
+    state.telemetry += [
+        log_telemetry(
+            telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
+        )
+    ]
+
+    # If user query is empty, return chat history
+    if not q:
+        return {"status": "ok", "response": meta_log.get("chat", [])}
+
+
+@api.get("/chat", response_class=StreamingResponse)
+async def chat(
+    request: Request,
+    q: Optional[str] = None,
+    client: Optional[str] = None,
+    user_agent: Optional[str] = Header(None),
+    referer: Optional[str] = Header(None),
+    host: Optional[str] = Header(None),
+) -> StreamingResponse:
+    if (
+        state.processor_config is None
+        or state.processor_config.conversation is None
+        or state.processor_config.conversation.openai_api_key is None
+    ):
+        raise HTTPException(
+            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
+        )
+
     # Load Conversation History
     chat_session = state.processor_config.conversation.chat_session
     meta_log = state.processor_config.conversation.meta_log

     # If user query is empty, return chat history
     if not q:
-        return {"status": "ok", "response": meta_log.get("chat", [])}
+        return StreamingResponse(None)

     # Initialize Variables
     api_key = state.processor_config.conversation.openai_api_key
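
The old /api/chat route is effectively split in two here: /api/chat/init serves chat history and records telemetry synchronously, while /api/chat only streams answers. A short client sketch of the split flow, reusing the assumed local address from the earlier sketch:

import requests

BASE_URL = "http://localhost:42110"  # assumption: local Khoj server address

# Load existing chat history once, e.g. on page load ...
history = requests.get(f"{BASE_URL}/api/chat/init", params={"client": "web"}).json()
print(history.get("response", []))
# ... then stream each answer from /api/chat as shown in the earlier sketch.
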
@@ -446,24 +486,6 @@ async def chat(
     conversation_type = "notes" if compiled_references else "general"
     logger.debug(f"Conversation Type: {conversation_type}")

-    try:
-        with timer("Generating chat response took", logger):
-            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
-        status = "ok"
-    except Exception as e:
-        gpt_response = str(e)
-        status = "error"
-
-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q,
-        gpt_response,
-        user_message_metadata={"created": user_message_time},
-        khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
-        conversation_log=meta_log.get("chat", []),
-    )
-
     user_state = {
         "client_host": request.client.host,
         "user_agent": user_agent or "unknown",
@@ -477,4 +499,20 @@ async def chat(
         )
     ]

-    return {"status": status, "response": gpt_response, "context": compiled_references}
+    try:
+        with timer("Generating chat response took", logger):
+            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+    except Exception as e:
+        gpt_response = str(e)
+
+    # Update Conversation History
+    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
+    #     q,
+    #     gpt_response,
+    #     user_message_metadata={"created": user_message_time},
+    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+    #     conversation_log=meta_log.get("chat", []),
+    # )
+
+    return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
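
StreamingResponse accepts any sync or async iterable, which is why the ThreadedGenerator returned by converse() can be passed through unchanged; note the declared media type is text/event-stream even though the body is raw token text rather than SSE-framed events. A self-contained sketch of the pattern, with a hypothetical /demo/stream route:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def token_source():
    # Stand-in for the ThreadedGenerator: any iterator of str or bytes works
    for token in ["Hello", ", ", "world", "!"]:
        yield token

@app.get("/demo/stream")  # hypothetical route, for illustration only
def demo_stream():
    return StreamingResponse(token_source(), media_type="text/event-stream")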