diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index b9f14884..f9f6aca1 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -35,8 +35,7 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
     )
 
     # Extract, Clean Message from GPT's Response
-    story = response["choices"][0]["text"]
-    return str(story).replace("\n\n", "")
+    return str(response).replace("\n\n", "")
 
 
 def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
@@ -62,8 +61,7 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     )
 
     # Extract, Clean Message from GPT's Response
-    story = response["choices"][0]["text"]
-    return str(story).replace("\n\n", "")
+    return str(response).replace("\n\n", "")
 
 
 def extract_questions(text, model="text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100):
@@ -106,17 +104,16 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
     )
 
     # Extract, Clean Message from GPT's Response
-    response_text = response["choices"][0]["text"]
     try:
         questions = json.loads(
             # Clean response to increase likelihood of valid JSON. E.g replace ' with " to enclose strings
-            response_text.strip(empty_escape_sequences)
+            response.strip(empty_escape_sequences)
            .replace("['", '["')
            .replace("']", '"]')
            .replace("', '", '", "')
         )
     except json.decoder.JSONDecodeError:
-        logger.warn(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response_text}")
+        logger.warn(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
         questions = [text]
     logger.debug(f"Extracted Questions by GPT: {questions}")
     return questions
@@ -144,8 +141,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     )
 
     # Extract, Clean Message from GPT's Response
-    story = str(response["choices"][0]["text"])
-    return json.loads(story.strip(empty_escape_sequences))
+    return json.loads(response.strip(empty_escape_sequences))
 
 
 def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo", api_key=None, temperature=0.2):
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 3e3d38e7..e18f7bbd 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -5,6 +5,7 @@ from datetime import datetime
 
 # External Packages
 from langchain.chat_models import ChatOpenAI
+from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
 import openai
 import tiktoken
@@ -39,8 +40,10 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
     reraise=True,
 )
 def completion_with_backoff(**kwargs):
-    openai.api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
-    return openai.Completion.create(**kwargs, request_timeout=60)
+    prompt = kwargs.pop("prompt")
+    kwargs["openai_api_key"] = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    llm = OpenAI(**kwargs, request_timeout=60)
+    return llm(prompt)
 
 
 @retry(