From eb86f6fc428fd6f21ebf8c629aca41c3ea6d8f45 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 30 Sep 2024 03:01:40 -0700 Subject: [PATCH] Add __str__ func to LocationData class to dedupe location string gen Previously the location string from location data was being generated wherever it was being used. By adding a __str__ representation to LocationData class, we can dedupe and simplify the code to get the location string --- .../processor/conversation/anthropic/anthropic_chat.py | 5 ++--- src/khoj/processor/conversation/google/gemini_chat.py | 5 ++--- src/khoj/processor/conversation/offline/chat_model.py | 5 ++--- src/khoj/processor/conversation/openai/gpt.py | 5 ++--- src/khoj/routers/helpers.py | 7 +++---- src/khoj/utils/rawconfig.py | 10 ++++++++++ 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 94d8df03..0309d29b 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -32,7 +32,7 @@ def extract_questions_anthropic( Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -158,8 +158,7 @@ def converse_anthropic( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 64fa5d66..a2ccc87b 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -33,7 +33,7 @@ def extract_questions_gemini( Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -163,8 +163,7 @@ def converse_gemini( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index febe3786..a8229332 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -46,7 +46,7 @@ def extract_questions_offline( assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -171,8 +171,7 @@ def converse_offline( conversation_primer = prompts.query_prompt.format(query=user_query) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 90cd4df9..1361e1ae 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -36,7 +36,7 @@ def extract_questions( """ Infer search queries to retrieve relevant notes to answer user query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -159,8 +159,7 @@ def converse( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0f60b100..59d44925 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -369,7 +369,7 @@ async def infer_webpage_urls( """ Infer webpage links from the given query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) @@ -405,7 +405,7 @@ async def generate_online_subqueries( """ Generate subqueries from the given query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) @@ -535,8 +535,7 @@ async def generate_better_image_prompt( model_type = model_type or TextToImageModelConfig.ModelType.OPENAI if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") else: location_prompt = "Unknown" diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 79318b01..f0cd3962 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -26,6 +26,16 @@ class LocationData(BaseModel): region: Optional[str] country: Optional[str] + def __str__(self): + parts = [] + if self.city: + parts.append(self.city) + if self.region: + parts.append(self.region) + if self.country: + parts.append(self.country) + return ", ".join(parts) + class FileFilterRequest(BaseModel): filename: str