From eb86f6fc428fd6f21ebf8c629aca41c3ea6d8f45 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 30 Sep 2024 03:01:40 -0700 Subject: [PATCH 1/4] Add __str__ func to LocationData class to dedupe location string gen Previously the location string from location data was being generated wherever it was being used. By adding a __str__ representation to LocationData class, we can dedupe and simplify the code to get the location string --- .../processor/conversation/anthropic/anthropic_chat.py | 5 ++--- src/khoj/processor/conversation/google/gemini_chat.py | 5 ++--- src/khoj/processor/conversation/offline/chat_model.py | 5 ++--- src/khoj/processor/conversation/openai/gpt.py | 5 ++--- src/khoj/routers/helpers.py | 7 +++---- src/khoj/utils/rawconfig.py | 10 ++++++++++ 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 94d8df03..0309d29b 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -32,7 +32,7 @@ def extract_questions_anthropic( Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -158,8 +158,7 @@ def converse_anthropic( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 64fa5d66..a2ccc87b 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -33,7 +33,7 @@ def extract_questions_gemini( Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -163,8 +163,7 @@ def converse_gemini( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index febe3786..a8229332 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -46,7 +46,7 @@ def extract_questions_offline( assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -171,8 +171,7 @@ def converse_offline( conversation_primer = prompts.query_prompt.format(query=user_query) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 90cd4df9..1361e1ae 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -36,7 +36,7 @@ def extract_questions( """ Infer search queries to retrieve relevant notes to answer user query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log @@ -159,8 +159,7 @@ def converse( ) if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") system_prompt = f"{system_prompt}\n{location_prompt}" if user_name: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0f60b100..59d44925 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -369,7 +369,7 @@ async def infer_webpage_urls( """ Infer webpage links from the given query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) @@ -405,7 +405,7 @@ async def generate_online_subqueries( """ Generate subqueries from the given query """ - location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) @@ -535,8 +535,7 @@ async def generate_better_image_prompt( model_type = model_type or TextToImageModelConfig.ModelType.OPENAI if location_data: - location = f"{location_data.city}, {location_data.region}, {location_data.country}" - location_prompt = prompts.user_location.format(location=location) + location_prompt = prompts.user_location.format(location=f"{location_data}") else: location_prompt = "Unknown" diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 79318b01..f0cd3962 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -26,6 +26,16 @@ class LocationData(BaseModel): region: Optional[str] country: Optional[str] + def __str__(self): + parts = [] + if self.city: + parts.append(self.city) + if self.region: + parts.append(self.region) + if self.country: + parts.append(self.country) + return ", ".join(parts) + class FileFilterRequest(BaseModel): filename: str From 1fed842fcca221e02c03f411821d3d7f89491213 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 30 Sep 2024 01:19:56 -0700 Subject: [PATCH 2/4] Localize online search results to user country when location available Get country code to server chat api from i.p location check on clients. Use country code to get country specific online search results via Serper.dev API --- src/interface/desktop/chat.html | 3 +++ src/interface/desktop/shortcut.html | 6 ++++-- src/interface/obsidian/src/chat_view.ts | 3 +++ src/interface/web/app/chat/page.tsx | 3 ++- src/interface/web/app/common/utils.ts | 1 + src/interface/web/app/share/chat/page.tsx | 1 + src/khoj/processor/tools/online_search.py | 10 +++++----- src/khoj/routers/api_chat.py | 6 ++++-- src/khoj/utils/rawconfig.py | 1 + 9 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 3df00efc..cd47dae5 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -60,6 +60,7 @@ let region = null; let city = null; let countryName = null; + let countryCode = null; let timezone = null; let chatMessageState = { newResponseTextEl: null, @@ -76,6 +77,7 @@ region = data.region; city = data.city; countryName = data.country_name; + countryCode = data.country_code; timezone = data.timezone; }) .catch(err => { @@ -157,6 +159,7 @@ ...(!!city && { city: city }), ...(!!region && { region: region }), ...(!!countryName && { country: countryName }), + ...(!!countryCode && { country_code: countryCode }), ...(!!timezone && { timezone: timezone }), }; diff --git a/src/interface/desktop/shortcut.html b/src/interface/desktop/shortcut.html index 86e5d906..4b07b8ea 100644 --- a/src/interface/desktop/shortcut.html +++ b/src/interface/desktop/shortcut.html @@ -308,18 +308,19 @@