Add __str__ func to LocationData class to dedupe location string gen

Previously the location string from location data was being generated
wherever it was being used.

By adding a __str__ representation to LocationData class, we can
dedupe and simplify the code to get the location string
This commit is contained in:
Debanjum Singh Solanky
2024-09-30 03:01:40 -07:00
parent d21a4e73a0
commit eb86f6fc42
6 changed files with 21 additions and 16 deletions

View File

@@ -32,7 +32,7 @@ def extract_questions_anthropic(
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@@ -158,8 +158,7 @@ def converse_anthropic(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View File

@@ -33,7 +33,7 @@ def extract_questions_gemini(
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@@ -163,8 +163,7 @@ def converse_gemini(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View File

@@ -46,7 +46,7 @@ def extract_questions_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@@ -171,8 +171,7 @@ def converse_offline(
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View File

@@ -36,7 +36,7 @@ def extract_questions(
"""
Infer search queries to retrieve relevant notes to answer user query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
@@ -159,8 +159,7 @@ def converse(
)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:

View File

@@ -369,7 +369,7 @@ async def infer_webpage_urls(
"""
Infer webpage links from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
@@ -405,7 +405,7 @@ async def generate_online_subqueries(
"""
Generate subqueries from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
@@ -535,8 +535,7 @@ async def generate_better_image_prompt(
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
location_prompt = prompts.user_location.format(location=f"{location_data}")
else:
location_prompt = "Unknown"

View File

@@ -26,6 +26,16 @@ class LocationData(BaseModel):
region: Optional[str]
country: Optional[str]
def __str__(self):
parts = []
if self.city:
parts.append(self.city)
if self.region:
parts.append(self.region)
if self.country:
parts.append(self.country)
return ", ".join(parts)
class FileFilterRequest(BaseModel):
filename: str