Mirror of https://github.com/khoaliber/khoj.git
Initial code with chat streaming working (warning: messy code)
Web chat frontend (JavaScript):

@@ -64,14 +64,47 @@
     // Generate backend API URL to execute query
     let url = `/api/chat?q=${encodeURIComponent(query)}&client=web`;

-    // Call specified Khoj API
+    let chat_body = document.getElementById("chat-body");
+    let new_response = document.createElement("div");
+    new_response.classList.add("chat-message", "khoj");
+    new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
+    chat_body.appendChild(new_response);
+
+    let new_response_text = document.createElement("div");
+    new_response_text.classList.add("chat-message-text", "khoj");
+    new_response.appendChild(new_response_text);
+
+    // Call specified Khoj API which returns a streamed response of type text/plain
     fetch(url)
-        .then(response => response.json())
-        .then(data => {
-            // Render message by Khoj to chat body
-            console.log(data.response);
-            renderMessageWithReference(data.response, "khoj", data.context);
+        .then(response => {
+            console.log(response);
+            const reader = response.body.getReader();
+            const decoder = new TextDecoder();
+
+            function readStream() {
+                reader.read().then(({ done, value }) => {
+                    if (done) {
+                        console.log("Stream complete");
+                        return;
+                    }
+
+                    const chunk = decoder.decode(value, { stream: true });
+                    new_response_text.innerHTML += chunk;
+                    console.log(`Received ${chunk.length} bytes of data`);
+                    console.log(`Chunk: ${chunk}`);
+                    readStream();
+                });
+            }
+            readStream();
         });
+
+
+    // fetch(url)
+    //     .then(data => {
+    //         // Render message by Khoj to chat body
+    //         console.log(data.response);
+    //         renderMessageWithReference(data.response, "khoj", data.context);
+    //     });
 }

 function incrementalChat(event) {
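
A note on the pattern above: response.body.getReader() pulls raw bytes off the wire as they arrive, TextDecoder with { stream: true } reassembles multi-byte characters across chunk boundaries, and readStream() recurses until the stream reports done. For the backend half of this contract, a minimal FastAPI sketch (illustrative only; the real endpoint appears in the API router hunks further down, and token_stream here is a stand-in for the commit's actual token source):

# Minimal sketch of a streaming endpoint the fetch() reader above can consume.
# token_stream is an illustrative stand-in, not code from this commit.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def token_stream():
    # Each yielded string is written to the wire as its own chunk.
    for token in ["Hello", ", ", "world", "!"]:
        yield token

@app.get("/api/chat")
def chat():
    return StreamingResponse(token_stream(), media_type="text/event-stream")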
@@ -82,7 +115,7 @@
 }

 window.onload = function () {
-    fetch('/api/chat?client=web')
+    fetch('/api/chat/init?client=web')
         .then(response => response.json())
         .then(data => {
             if (data.detail) {
Conversation module, converse() (Python):

@@ -170,12 +170,18 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",

     # Get Response from GPT
     logger.debug(f"Conversation Context for GPT: {messages}")
-    response = chat_completion_with_backoff(
+    return chat_completion_with_backoff(
         messages=messages,
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
     )

-    # Extract, Clean Message from GPT's Response
-    return response.strip(empty_escape_sequences)
+    # async for tokens in chat_completion_with_backoff(
+    #     messages=messages,
+    #     model_name=model,
+    #     temperature=temperature,
+    #     openai_api_key=api_key,
+    # ):
+    #     logger.info(f"Tokens from GPT: {tokens}")
+    #     yield tokens
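
With this change converse no longer returns a cleaned string; it hands back whatever chat_completion_with_backoff returns, which (per the utility hunks below) is an iterable ThreadedGenerator of tokens. A hypothetical caller, with stand-in arguments since converse's full signature is truncated in the hunk header above:

# Hypothetical usage: tokens arrive incrementally instead of as one final string.
references, meta_log = [], {}
for token in converse(references, "What did I write about habits?", meta_log, api_key="..."):
    print(token, end="", flush=True)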
Chat-completion utilities (Python):

@@ -2,11 +2,19 @@
 import os
 import logging
 from datetime import datetime
+from typing import Any, Optional
+from uuid import UUID
+import asyncio
+from threading import Thread

 # External Packages
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager, AsyncCallbackHandler
 import openai
 import tiktoken
 from tenacity import (
@@ -20,12 +28,43 @@ from tenacity import (

 # Internal Packages
 from khoj.utils.helpers import merge_dicts
+import queue


 logger = logging.getLogger(__name__)
 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}


+class ThreadedGenerator:
+    def __init__(self):
+        self.queue = queue.Queue()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = self.queue.get()
+        if item is StopIteration:
+            raise item
+        return item
+
+    def send(self, data):
+        self.queue.put(data)
+
+    def close(self):
+        self.queue.put(StopIteration)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+    def __init__(self, gen: ThreadedGenerator):
+        super().__init__()
+        self.gen = gen
+
+    def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        logger.debug(f"New Token: {token}")
+        self.gen.send(token)
+
+
 @retry(
     retry=(
         retry_if_exception_type(openai.error.Timeout)
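
A minimal sketch of how ThreadedGenerator is meant to be driven, independent of LangChain (fake_llm below is illustrative, not code from this commit): the producer thread enqueues tokens with send(), the consumer blocks in __next__ on queue.get(), and close() enqueues the StopIteration sentinel that __next__ re-raises to terminate iteration.

from threading import Thread

def fake_llm(g):
    # Illustrative stand-in for the streaming callback pushing tokens.
    for token in ["Hello", ", ", "world", "!"]:
        g.send(token)
    g.close()  # enqueue the StopIteration sentinel so iteration ends

g = ThreadedGenerator()
Thread(target=fake_llm, args=(g,)).start()
for token in g:  # blocks on queue.get() until the producer sends or closes
    print(token, end="", flush=True)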
@@ -63,14 +102,28 @@ def completion_with_backoff(**kwargs):
     reraise=True,
 )
 def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
+    g = ThreadedGenerator()
+    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+    t.start()
+    return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+    callback_handler = StreamingChatCallbackHandler(g)
     chat = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callback_manager=BaseCallbackManager([callback_handler]),
         model_name=model_name,
         temperature=temperature,
         openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
         request_timeout=20,
         max_retries=1,
     )
-    return chat(messages).content
+
+    chat(messages=messages)
+
+    g.close()


 def generate_chatml_messages_with_context(
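
One caveat with llm_thread as written: if chat(messages=messages) raises (timeout, rate limit, API error), g.close() is never reached and any consumer blocks forever on queue.get(). A defensive variant, assuming the same classes from this commit, moves the close into a finally block:

def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
    try:
        callback_handler = StreamingChatCallbackHandler(g)
        chat = ChatOpenAI(
            streaming=True,
            verbose=True,
            callback_manager=BaseCallbackManager([callback_handler]),
            model_name=model_name,
            temperature=temperature,
            openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
            request_timeout=20,
            max_retries=1,
        )
        chat(messages=messages)
    finally:
        g.close()  # always unblock the consumer, even when the LLM call fails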
API router (Python):

@@ -34,6 +34,7 @@ from khoj.utils.rawconfig import (
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
 from khoj.utils.yaml import save_config_to_file_updated_state
+from fastapi.responses import StreamingResponse

 # Initialize Router
 api = APIRouter()
@@ -393,8 +394,8 @@ def update(
     return {"status": "ok", "message": "khoj reloaded"}


-@api.get("/chat")
-async def chat(
+@api.get("/chat/init")
+def chat_init(
     request: Request,
     q: Optional[str] = None,
     client: Optional[str] = None,
@@ -411,13 +412,52 @@ async def chat(
             status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
         )

+    # Load Conversation History
+    meta_log = state.processor_config.conversation.meta_log
+
+    user_state = {
+        "client_host": request.client.host,
+        "user_agent": user_agent or "unknown",
+        "referer": referer or "unknown",
+        "host": host or "unknown",
+    }
+
+    state.telemetry += [
+        log_telemetry(
+            telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
+        )
+    ]
+
+    # If user query is empty, return chat history
+    if not q:
+        return {"status": "ok", "response": meta_log.get("chat", [])}
+
+
+@api.get("/chat", response_class=StreamingResponse)
+async def chat(
+    request: Request,
+    q: Optional[str] = None,
+    client: Optional[str] = None,
+    user_agent: Optional[str] = Header(None),
+    referer: Optional[str] = Header(None),
+    host: Optional[str] = Header(None),
+) -> StreamingResponse:
+    if (
+        state.processor_config is None
+        or state.processor_config.conversation is None
+        or state.processor_config.conversation.openai_api_key is None
+    ):
+        raise HTTPException(
+            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
+        )
+
     # Load Conversation History
     chat_session = state.processor_config.conversation.chat_session
     meta_log = state.processor_config.conversation.meta_log

     # If user query is empty, return chat history
     if not q:
-        return {"status": "ok", "response": meta_log.get("chat", [])}
+        return StreamingResponse(None)

     # Initialize Variables
     api_key = state.processor_config.conversation.openai_api_key
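
A hedged caveat on the empty-query branch above: StreamingResponse(None) gives Starlette a non-iterable body, so an empty q will most likely fail when the response is sent. A safer sketch, assuming the same imports plus the standard json module, streams the serialized history as a single chunk:

import json

# Sketch: return prior chat history as a one-chunk stream when q is empty.
if not q:
    history = json.dumps({"status": "ok", "response": meta_log.get("chat", [])})
    return StreamingResponse(iter([history]), media_type="application/json")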
@@ -446,24 +486,6 @@ async def chat(
     conversation_type = "notes" if compiled_references else "general"
     logger.debug(f"Conversation Type: {conversation_type}")

-    try:
-        with timer("Generating chat response took", logger):
-            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
-        status = "ok"
-    except Exception as e:
-        gpt_response = str(e)
-        status = "error"
-
-    # Update Conversation History
-    state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
-    state.processor_config.conversation.meta_log["chat"] = message_to_log(
-        q,
-        gpt_response,
-        user_message_metadata={"created": user_message_time},
-        khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
-        conversation_log=meta_log.get("chat", []),
-    )
-
     user_state = {
         "client_host": request.client.host,
         "user_agent": user_agent or "unknown",
@@ -477,4 +499,20 @@ async def chat(
         )
     ]

-    return {"status": status, "response": gpt_response, "context": compiled_references}
+    try:
+        with timer("Generating chat response took", logger):
+            gpt_response = converse(compiled_references, q, meta_log, model=chat_model, api_key=api_key)
+    except Exception as e:
+        gpt_response = str(e)
+
+    # Update Conversation History
+    # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+    # state.processor_config.conversation.meta_log["chat"] = message_to_log(
+    #     q,
+    #     gpt_response,
+    #     user_message_metadata={"created": user_message_time},
+    #     khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
+    #     conversation_log=meta_log.get("chat", []),
+    # )
+
+    return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
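
For reference, a minimal Python client for the new streaming endpoint. The host and port are assumptions; the /api/chat route and the q/client query parameters come from this commit:

import requests

# Assumed local server address; adjust to wherever Khoj is running.
with requests.get(
    "http://localhost:8000/api/chat",
    params={"q": "Hello", "client": "web"},
    stream=True,
) as response:
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)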