diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 986ffc17..df6d851a 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -78,6 +78,50 @@ Summarize the notes in second person perspective:""" return str(story).replace("\n\n", "") +def extract_questions(text, model="text-davinci-003", api_key=None, temperature=0, max_tokens=100): + """ + Infer search queries to retrieve relevant notes to answer user query + """ + # Initialize Variables + openai.api_key = api_key or os.getenv("OPENAI_API_KEY") + + # Get dates relative to today for prompt creation + today = datetime.today() + current_new_year = today.replace(month=1, day=1) + last_new_year = current_new_year.replace(year=today.year - 1) + + prompt = f""" +You are Khoj, a chat assistant with the ability to search the users notes +What searches, if any, will you need to perform to answer the users question below? Provide search queries as a JSON list of strings +Current Date: {today.strftime("%HH:%MM %A, %Y-%m-%d")} + +Q: How was my trip to Cambodia? + +["My Cambodia trip experience"] + +Q: How are you feeling? + +[] + +Q: What national parks did I go to last year? + +["National park I visited in {last_new_year.strftime("%Y")} dt>=\\"{last_new_year.strftime("%Y-%m-%d")}\\" dt<\\"{current_new_year.strftime("%Y-%m-%d")}\\""] + +Q: Is Bob older than Tom? + +["When was Bob born?", "What is Tom's age?"] + +Q: {text}""" + + # Get Response from GPT + response = openai.Completion.create(prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens) + + # Extract, Clean Message from GPT's Response + questions = json.loads(response["choices"][0]["text"].strip(empty_escape_sequences)) + logger.debug(f"Extracted Questions by GPT: {questions}") + return questions + + def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0): """ Extract search type from user query using OpenAI's GPT diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4839df48..2bbb0edf 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -10,7 +10,7 @@ from fastapi import HTTPException # Internal Packages from khoj.configure import configure_processor, configure_search -from khoj.processor.conversation.gpt import converse +from khoj.processor.conversation.gpt import converse, extract_questions from khoj.processor.conversation.utils import message_to_log, message_to_prompt from khoj.search_type import image_search, text_search from khoj.utils.helpers import timer @@ -191,6 +191,7 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False): def chat(q: Optional[str] = None): # Initialize Variables api_key = state.processor_config.conversation.openai_api_key + model = state.processor_config.conversation.model # Load Conversation History chat_session = state.processor_config.conversation.chat_session @@ -203,9 +204,14 @@ def chat(q: Optional[str] = None): else: return {"status": "ok", "response": []} - # Collate context for GPT - result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False) - collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list]) + # Extract search queries from user message + queries = extract_questions(q, model=model, api_key=api_key) + + # Collate search results as context for GPT + result_list = [] + for query in queries: + result_list.extend(search(query, n=2, r=True, score_threshold=0, dedupe=False)) + collated_result = "\n\n".join({f"# {item.additional['compiled']}" for item in result_list}) logger.debug(f"Reference Context:\n{collated_result}") try: