diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 27782573..b86ebc6b 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -134,4 +134,5 @@ def converse( temperature=temperature, openai_api_key=api_key, completion_func=completion_func, + model_kwargs={"stop": ["Notes:\n["]}, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 130532e0..dce72e1f 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs): reraise=True, ) def chat_completion_with_backoff( - messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None + messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None ): g = ThreadedGenerator(compiled_references, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) + t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs)) t.start() return g -def llm_thread(g, messages, model_name, temperature, openai_api_key=None): +def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None): callback_handler = StreamingChatCallbackHandler(g) chat = ChatOpenAI( streaming=True, @@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None): model_name=model_name, # type: ignore temperature=temperature, openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), + model_kwargs=model_kwargs, request_timeout=20, max_retries=1, client=None,