mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Save streamed chat conversations via partial function passed to the ThreadGenerator
This commit is contained in:
@@ -144,7 +144,15 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
|
|||||||
return json.loads(response.strip(empty_escape_sequences))
|
return json.loads(response.strip(empty_escape_sequences))
|
||||||
|
|
||||||
|
|
||||||
def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2):
|
def converse(
|
||||||
|
references,
|
||||||
|
user_query,
|
||||||
|
conversation_log={},
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
api_key=None,
|
||||||
|
temperature=0.2,
|
||||||
|
completion_func=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using OpenAI's ChatGPT
|
Converse with user using OpenAI's ChatGPT
|
||||||
"""
|
"""
|
||||||
@@ -176,6 +184,7 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
|
completion_func=completion_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
# async for tokens in chat_completion_with_backoff(
|
# async for tokens in chat_completion_with_backoff(
|
||||||
|
|||||||
@@ -37,9 +37,11 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
|
|||||||
|
|
||||||
|
|
||||||
class ThreadedGenerator:
|
class ThreadedGenerator:
|
||||||
def __init__(self, compiled_references):
|
def __init__(self, compiled_references, completion_func=None):
|
||||||
self.queue = queue.Queue()
|
self.queue = queue.Queue()
|
||||||
self.compiled_references = compiled_references
|
self.compiled_references = compiled_references
|
||||||
|
self.completion_func = completion_func
|
||||||
|
self.response = ""
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -47,10 +49,13 @@ class ThreadedGenerator:
|
|||||||
def __next__(self):
|
def __next__(self):
|
||||||
item = self.queue.get()
|
item = self.queue.get()
|
||||||
if item is StopIteration:
|
if item is StopIteration:
|
||||||
|
if self.completion_func:
|
||||||
|
self.completion_func(gpt_response=self.response)
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def send(self, data):
|
def send(self, data):
|
||||||
|
self.response += data
|
||||||
self.queue.put(data)
|
self.queue.put(data)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@@ -65,7 +70,6 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
self.gen = gen
|
self.gen = gen
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
||||||
logger.debug(f"New Token: {token}")
|
|
||||||
self.gen.send(token)
|
self.gen.send(token)
|
||||||
|
|
||||||
|
|
||||||
@@ -105,8 +109,10 @@ def completion_with_backoff(**kwargs):
|
|||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def chat_completion_with_backoff(messages, compiled_references, model_name, temperature, openai_api_key=None):
|
def chat_completion_with_backoff(
|
||||||
g = ThreadedGenerator(compiled_references)
|
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||||
|
):
|
||||||
|
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import yaml
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, HTTPException, Header, Request
|
||||||
@@ -442,6 +443,24 @@ async def chat(
|
|||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
|
def _save_to_conversation_log(
|
||||||
|
q: str,
|
||||||
|
gpt_response: str,
|
||||||
|
user_message_time: str,
|
||||||
|
compiled_references: List[str],
|
||||||
|
inferred_queries: List[str],
|
||||||
|
chat_session: str,
|
||||||
|
meta_log,
|
||||||
|
):
|
||||||
|
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||||
|
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||||
|
q,
|
||||||
|
gpt_response,
|
||||||
|
user_message_metadata={"created": user_message_time},
|
||||||
|
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||||
|
conversation_log=meta_log.get("chat", []),
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
state.processor_config is None
|
state.processor_config is None
|
||||||
or state.processor_config.conversation is None
|
or state.processor_config.conversation is None
|
||||||
@@ -501,26 +520,21 @@ async def chat(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with timer("Generating chat response took", logger):
|
with timer("Generating chat response took", logger):
|
||||||
gpt_response = converse(
|
partial_completion = partial(
|
||||||
compiled_references,
|
_save_to_conversation_log,
|
||||||
q,
|
q,
|
||||||
meta_log,
|
user_message_time=user_message_time,
|
||||||
model=chat_model,
|
compiled_references=compiled_references,
|
||||||
api_key=api_key,
|
|
||||||
chat_session=chat_session,
|
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
|
chat_session=chat_session,
|
||||||
|
meta_log=meta_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gpt_response = converse(
|
||||||
|
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gpt_response = str(e)
|
gpt_response = str(e)
|
||||||
|
|
||||||
# Update Conversation History
|
|
||||||
# state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
|
||||||
# state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
|
||||||
# q,
|
|
||||||
# gpt_response,
|
|
||||||
# user_message_metadata={"created": user_message_time},
|
|
||||||
# khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
|
||||||
# conversation_log=meta_log.get("chat", []),
|
|
||||||
# )
|
|
||||||
|
|
||||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||||
|
|||||||
Reference in New Issue
Block a user