From 79b1b1d35018ad92d574fb4d6608c18838555620 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 4 Jul 2023 17:33:52 -0700 Subject: [PATCH] Save streamed chat conversations via partial function passed to the ThreadedGenerator --- src/khoj/processor/conversation/gpt.py | 11 +++++- src/khoj/processor/conversation/utils.py | 14 +++++--- src/khoj/routers/api.py | 46 +++++++++++++++--------- 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 1c993c37..14dcf87b 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -144,7 +144,15 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1 return json.loads(response.strip(empty_escape_sequences)) -def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2): +def converse( + references, + user_query, + conversation_log={}, + model="gpt-3.5-turbo", + api_key=None, + temperature=0.2, + completion_func=None, +): """ Converse with user using OpenAI's ChatGPT """ @@ -176,6 +184,7 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", model_name=model, temperature=temperature, openai_api_key=api_key, + completion_func=completion_func, ) # async for tokens in chat_completion_with_backoff( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index f2a87b3c..4305999f 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -37,9 +37,11 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192} class ThreadedGenerator: - def __init__(self, compiled_references): + def __init__(self, compiled_references, completion_func=None): self.queue = queue.Queue() self.compiled_references = compiled_references + self.completion_func = completion_func + self.response = "" def __iter__(self): return self @@ -47,10 
+49,13 @@ class ThreadedGenerator: def __next__(self): item = self.queue.get() if item is StopIteration: + if self.completion_func: + self.completion_func(gpt_response=self.response) raise StopIteration return item def send(self, data): + self.response += data self.queue.put(data) def close(self): @@ -65,7 +70,6 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): self.gen = gen def on_llm_new_token(self, token: str, **kwargs) -> Any: - logger.debug(f"New Token: {token}") self.gen.send(token) @@ -105,8 +109,10 @@ def completion_with_backoff(**kwargs): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def chat_completion_with_backoff(messages, compiled_references, model_name, temperature, openai_api_key=None): - g = ThreadedGenerator(compiled_references) +def chat_completion_with_backoff( + messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None +): + g = ThreadedGenerator(compiled_references, completion_func=completion_func) t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) t.start() return g diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 3acbd62a..d1fbbdb5 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,6 +6,7 @@ import yaml import logging from datetime import datetime from typing import List, Optional, Union +from functools import partial # External Packages from fastapi import APIRouter, HTTPException, Header, Request @@ -442,6 +443,24 @@ async def chat( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ) -> StreamingResponse: + def _save_to_conversation_log( + q: str, + gpt_response: str, + user_message_time: str, + compiled_references: List[str], + inferred_queries: List[str], + chat_session: str, + meta_log, + ): + state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) + 
state.processor_config.conversation.meta_log["chat"] = message_to_log( + q, + gpt_response, + user_message_metadata={"created": user_message_time}, + khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, + conversation_log=meta_log.get("chat", []), + ) + if ( state.processor_config is None or state.processor_config.conversation is None @@ -501,26 +520,21 @@ async def chat( try: with timer("Generating chat response took", logger): - gpt_response = converse( - compiled_references, + partial_completion = partial( + _save_to_conversation_log, q, - meta_log, - model=chat_model, - api_key=api_key, - chat_session=chat_session, + user_message_time=user_message_time, + compiled_references=compiled_references, inferred_queries=inferred_queries, + chat_session=chat_session, + meta_log=meta_log, ) + + gpt_response = converse( + compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion + ) + except Exception as e: gpt_response = str(e) - # Update Conversation History - # state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) - # state.processor_config.conversation.meta_log["chat"] = message_to_log( - # q, - # gpt_response, - # user_message_metadata={"created": user_message_time}, - # khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, - # conversation_log=meta_log.get("chat", []), - # ) - return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)