Add __str__ method to LocationData class to dedupe location string generation

Previously the location string from location data was being generated
wherever it was being used.

By adding a __str__ representation to the LocationData class, we can
dedupe and simplify the code used to generate the location string.
This commit is contained in:
Debanjum Singh Solanky
2024-09-30 03:01:40 -07:00
parent d21a4e73a0
commit eb86f6fc42
6 changed files with 21 additions and 16 deletions

View File

@@ -32,7 +32,7 @@ def extract_questions_anthropic(
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -158,8 +158,7 @@ def converse_anthropic(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -33,7 +33,7 @@ def extract_questions_gemini(
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -163,8 +163,7 @@ def converse_gemini(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -46,7 +46,7 @@ def extract_questions_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -171,8 +171,7 @@ def converse_offline(
conversation_primer = prompts.query_prompt.format(query=user_query) conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -36,7 +36,7 @@ def extract_questions(
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -159,8 +159,7 @@ def converse(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -369,7 +369,7 @@ async def infer_webpage_urls(
""" """
Infer webpage links from the given query Infer webpage links from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
@@ -405,7 +405,7 @@ async def generate_online_subqueries(
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
@@ -535,8 +535,7 @@ async def generate_better_image_prompt(
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
else: else:
location_prompt = "Unknown" location_prompt = "Unknown"

View File

@@ -26,6 +26,16 @@ class LocationData(BaseModel):
region: Optional[str] region: Optional[str]
country: Optional[str] country: Optional[str]
def __str__(self):
    """Return the location as a comma-separated string of its set parts.

    Falsy components (None or empty strings) are omitted, so a location
    with only city and country set renders as "City, Country" instead of
    interpolating a literal "None" for the missing region. An all-empty
    location renders as "".
    """
    # Join only the truthy components; order is city, region, country.
    return ", ".join(
        part for part in (self.city, self.region, self.country) if part
    )
class FileFilterRequest(BaseModel): class FileFilterRequest(BaseModel):
filename: str filename: str