From cba371678d7fc7b00d0ca4eccec359d137bdede2 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 10 Nov 2023 22:27:24 -0800 Subject: [PATCH] Stop OpenAI chat from emitting reference notes directly in chat body The Chat models sometimes output reference notes directly in the chat body in unformatted form, specifically as Notes:\n['. Prevent that. Reference notes are shown in clean, formatted form anyway --- src/khoj/processor/conversation/openai/gpt.py | 1 + src/khoj/processor/conversation/openai/utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 27782573..b86ebc6b 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -134,4 +134,5 @@ def converse( temperature=temperature, openai_api_key=api_key, completion_func=completion_func, + model_kwargs={"stop": ["Notes:\n["]}, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 130532e0..dce72e1f 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs): reraise=True, ) def chat_completion_with_backoff( - messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None + messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None ): g = ThreadedGenerator(compiled_references, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) + t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs)) t.start() return g -def llm_thread(g, messages, model_name, temperature, openai_api_key=None): +def llm_thread(g, messages, model_name, temperature, 
openai_api_key=None, model_kwargs=None): callback_handler = StreamingChatCallbackHandler(g) chat = ChatOpenAI( streaming=True, @@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None): model_name=model_name, # type: ignore temperature=temperature, openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), + model_kwargs=model_kwargs, request_timeout=20, max_retries=1, client=None,