diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index f9f6aca1..40cccd85 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -27,7 +27,7 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
     logger.debug(f"Prompt for GPT: {prompt}")
     response = completion_with_backoff(
         prompt=prompt,
-        model=model,
+        model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
         stop='"""',
@@ -52,7 +52,7 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     logger.debug(f"Prompt for GPT: {prompt}")
     response = completion_with_backoff(
         prompt=prompt,
-        model=model,
+        model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
         frequency_penalty=0.2,
@@ -96,7 +96,7 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
     # Get Response from GPT
     response = completion_with_backoff(
         prompt=prompt,
-        model=model,
+        model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
         stop=["A: ", "\n"],
@@ -132,7 +132,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     logger.debug(f"Prompt for GPT: {prompt}")
     response = completion_with_backoff(
         prompt=prompt,
-        model=model,
+        model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
         frequency_penalty=0.2,
@@ -174,9 +174,9 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
     logger.debug(f"Conversation Context for GPT: {messages}")
     response = chat_completion_with_backoff(
         messages=messages,
-        model=model,
+        model_name=model,
         temperature=temperature,
-        api_key=api_key,
+        openai_api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index f058b877..a95f16f5 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -41,7 +41,8 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
 )
 def completion_with_backoff(**kwargs):
     prompt = kwargs.pop("prompt")
-    kwargs["openai_api_key"] = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    if "openai_api_key" not in kwargs:
+        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
     llm = OpenAI(**kwargs, request_timeout=10, max_retries=1)
     return llm(prompt)

@@ -59,12 +60,11 @@ def completion_with_backoff(**kwargs):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def chat_completion_with_backoff(messages, model, temperature, **kwargs):
-    openai_api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
     chat = ChatOpenAI(
-        model_name=model,
+        model_name=model_name,
         temperature=temperature,
-        openai_api_key=openai_api_key,
+        openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
         request_timeout=10,
         max_retries=1,
     )
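
Note for reviewers: the net effect of this diff is that callers now pass langchain's own keyword names (model_name, openai_api_key) instead of ad-hoc names (model, api_key) that the helpers had to translate. Below is a minimal sketch of how the refactored helpers are called after this change. The keyword names and the OPENAI_API_KEY fallback come from the diff above; the HumanMessage import and the concrete prompt/message values are illustrative assumptions about the caller's setup, not part of this patch.

# Usage sketch (assumptions: langchain's HumanMessage is used to build the
# chat message list; prompt/model values are placeholders).
import os
from langchain.schema import HumanMessage

from khoj.processor.conversation.utils import (
    chat_completion_with_backoff,
    completion_with_backoff,
)

# Both helpers fall back to this env var when no key is passed explicitly.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

# completion_with_backoff forwards **kwargs straight into langchain's OpenAI(),
# so callers must now use its keyword names (model_name, openai_api_key).
answer = completion_with_backoff(
    prompt='Q: What is Khoj?\nA:',
    model_name="text-davinci-003",
    temperature=0.5,
    max_tokens=50,
)

# chat_completion_with_backoff now declares explicit parameters instead of
# **kwargs; openai_api_key is optional and defaults to the env var above.
response = chat_completion_with_backoff(
    messages=[HumanMessage(content="What is Khoj?")],
    model_name="gpt-3.5-turbo",
    temperature=0.2,
)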