mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 05:39:12 +00:00
Make Online Search Location Aware (#929)
## Overview Add user country code as context for doing online search with serper.dev API. This should find more user relevant results from online searches by Khoj ## Details ### Major - Default to using system clock to infer user timezone on js clients - Infer country from timezone when only timezone received by chat API - Localize online search results to user country when location available ### Minor - Add `__str__` func to `LocationData` class to deduplicate location string generation
This commit is contained in:
@@ -60,7 +60,8 @@
|
|||||||
let region = null;
|
let region = null;
|
||||||
let city = null;
|
let city = null;
|
||||||
let countryName = null;
|
let countryName = null;
|
||||||
let timezone = null;
|
let countryCode = null;
|
||||||
|
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||||
let chatMessageState = {
|
let chatMessageState = {
|
||||||
newResponseTextEl: null,
|
newResponseTextEl: null,
|
||||||
newResponseEl: null,
|
newResponseEl: null,
|
||||||
@@ -76,6 +77,7 @@
|
|||||||
region = data.region;
|
region = data.region;
|
||||||
city = data.city;
|
city = data.city;
|
||||||
countryName = data.country_name;
|
countryName = data.country_name;
|
||||||
|
countryCode = data.country_code;
|
||||||
timezone = data.timezone;
|
timezone = data.timezone;
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
@@ -157,6 +159,7 @@
|
|||||||
...(!!city && { city: city }),
|
...(!!city && { city: city }),
|
||||||
...(!!region && { region: region }),
|
...(!!region && { region: region }),
|
||||||
...(!!countryName && { country: countryName }),
|
...(!!countryName && { country: countryName }),
|
||||||
|
...(!!countryCode && { country_code: countryCode }),
|
||||||
...(!!timezone && { timezone: timezone }),
|
...(!!timezone && { timezone: timezone }),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -308,18 +308,19 @@
|
|||||||
<script src="./utils.js"></script>
|
<script src="./utils.js"></script>
|
||||||
<script src="./chatutils.js"></script>
|
<script src="./chatutils.js"></script>
|
||||||
<script>
|
<script>
|
||||||
|
|
||||||
let region = null;
|
let region = null;
|
||||||
let city = null;
|
let city = null;
|
||||||
let countryName = null;
|
let countryName = null;
|
||||||
let timezone = null;
|
let countryCode = null;
|
||||||
|
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||||
|
|
||||||
fetch("https://ipapi.co/json")
|
fetch("https://ipapi.co/json")
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
region = data.region;
|
|
||||||
city = data.city;
|
city = data.city;
|
||||||
|
region = data.region;
|
||||||
countryName = data.country_name;
|
countryName = data.country_name;
|
||||||
|
countryCode = data.country_code;
|
||||||
timezone = data.timezone;
|
timezone = data.timezone;
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
@@ -410,6 +411,7 @@
|
|||||||
...(!!city && { city: city }),
|
...(!!city && { city: city }),
|
||||||
...(!!region && { region: region }),
|
...(!!region && { region: region }),
|
||||||
...(!!countryName && { country: countryName }),
|
...(!!countryName && { country: countryName }),
|
||||||
|
...(!!countryCode && { country_code: countryCode }),
|
||||||
...(!!timezone && { timezone: timezone }),
|
...(!!timezone && { timezone: timezone }),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -33,9 +33,10 @@ interface ChatMessageState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
interface Location {
|
interface Location {
|
||||||
region: string;
|
region?: string;
|
||||||
city: string;
|
city?: string;
|
||||||
countryName: string;
|
countryName?: string;
|
||||||
|
countryCode?: string;
|
||||||
timezone: string;
|
timezone: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ export class KhojChatView extends KhojPaneView {
|
|||||||
result: string;
|
result: string;
|
||||||
setting: KhojSetting;
|
setting: KhojSetting;
|
||||||
waitingForLocation: boolean;
|
waitingForLocation: boolean;
|
||||||
location: Location;
|
location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
|
||||||
keyPressTimeout: NodeJS.Timeout | null = null;
|
keyPressTimeout: NodeJS.Timeout | null = null;
|
||||||
userMessages: string[] = []; // Store user sent messages for input history cycling
|
userMessages: string[] = []; // Store user sent messages for input history cycling
|
||||||
currentMessageIndex: number = -1; // Track current message index in userMessages array
|
currentMessageIndex: number = -1; // Track current message index in userMessages array
|
||||||
@@ -70,6 +71,7 @@ export class KhojChatView extends KhojPaneView {
|
|||||||
region: data.region,
|
region: data.region,
|
||||||
city: data.city,
|
city: data.city,
|
||||||
countryName: data.country_name,
|
countryName: data.country_name,
|
||||||
|
countryCode: data.country_code,
|
||||||
timezone: data.timezone,
|
timezone: data.timezone,
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
@@ -1056,12 +1058,11 @@ export class KhojChatView extends KhojPaneView {
|
|||||||
n: this.setting.resultsCount,
|
n: this.setting.resultsCount,
|
||||||
stream: true,
|
stream: true,
|
||||||
...(!!conversationId && { conversation_id: conversationId }),
|
...(!!conversationId && { conversation_id: conversationId }),
|
||||||
...(!!this.location && {
|
...(!!this.location && this.location.city && { city: this.location.city }),
|
||||||
city: this.location.city,
|
...(!!this.location && this.location.region && { region: this.location.region }),
|
||||||
region: this.location.region,
|
...(!!this.location && this.location.countryName && { country: this.location.countryName }),
|
||||||
country: this.location.countryName,
|
...(!!this.location && this.location.countryCode && { country_code: this.location.countryCode }),
|
||||||
timezone: this.location.timezone,
|
...(!!this.location && this.location.timezone && { timezone: this.location.timezone }),
|
||||||
}),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let newResponseEl = this.createKhojResponseDiv();
|
let newResponseEl = this.createKhojResponseDiv();
|
||||||
|
|||||||
@@ -518,12 +518,14 @@ function EditCard(props: EditCardProps) {
|
|||||||
updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`;
|
updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`;
|
||||||
}
|
}
|
||||||
updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`;
|
updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`;
|
||||||
if (props.locationData) {
|
if (props.locationData && props.locationData.city)
|
||||||
updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`;
|
updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`;
|
||||||
|
if (props.locationData && props.locationData.region)
|
||||||
updateQueryUrl += `®ion=${encodeURIComponent(props.locationData.region)}`;
|
updateQueryUrl += `®ion=${encodeURIComponent(props.locationData.region)}`;
|
||||||
|
if (props.locationData && props.locationData.country)
|
||||||
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
|
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
|
||||||
|
if (props.locationData && props.locationData.timezone)
|
||||||
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
|
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
|
||||||
}
|
|
||||||
|
|
||||||
let method = props.createNew ? "POST" : "PUT";
|
let method = props.createNew ? "POST" : "PUT";
|
||||||
|
|
||||||
|
|||||||
@@ -136,7 +136,9 @@ export default function Chat() {
|
|||||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||||
const [image64, setImage64] = useState<string>("");
|
const [image64, setImage64] = useState<string>("");
|
||||||
|
|
||||||
const locationData = useIPLocationData();
|
const locationData = useIPLocationData() || {
|
||||||
|
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||||
|
};
|
||||||
const authenticatedData = useAuthenticatedData();
|
const authenticatedData = useAuthenticatedData();
|
||||||
const isMobileWidth = useIsMobileWidth();
|
const isMobileWidth = useIsMobileWidth();
|
||||||
|
|
||||||
@@ -241,9 +243,10 @@ export default function Chat() {
|
|||||||
conversation_id: conversationId,
|
conversation_id: conversationId,
|
||||||
stream: true,
|
stream: true,
|
||||||
...(locationData && {
|
...(locationData && {
|
||||||
|
city: locationData.city,
|
||||||
region: locationData.region,
|
region: locationData.region,
|
||||||
country: locationData.country,
|
country: locationData.country,
|
||||||
city: locationData.city,
|
country_code: locationData.countryCode,
|
||||||
timezone: locationData.timezone,
|
timezone: locationData.timezone,
|
||||||
}),
|
}),
|
||||||
...(image64 && { image: image64 }),
|
...(image64 && { image: image64 }),
|
||||||
|
|||||||
@@ -2,13 +2,10 @@ import { useEffect, useState } from "react";
|
|||||||
import useSWR from "swr";
|
import useSWR from "swr";
|
||||||
|
|
||||||
export interface LocationData {
|
export interface LocationData {
|
||||||
ip: string;
|
city?: string;
|
||||||
city: string;
|
region?: string;
|
||||||
region: string;
|
country?: string;
|
||||||
country: string;
|
countryCode?: string;
|
||||||
postal: string;
|
|
||||||
latitude: number;
|
|
||||||
longitude: number;
|
|
||||||
timezone: string;
|
timezone: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,9 +47,7 @@ export function useIPLocationData() {
|
|||||||
{ revalidateOnFocus: false },
|
{ revalidateOnFocus: false },
|
||||||
);
|
);
|
||||||
|
|
||||||
if (locationDataError) return null;
|
if (locationDataError || !locationData) return;
|
||||||
if (!locationData) return null;
|
|
||||||
|
|
||||||
return locationData;
|
return locationData;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,9 @@ export default function SharedChat() {
|
|||||||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||||
const [image64, setImage64] = useState<string>("");
|
const [image64, setImage64] = useState<string>("");
|
||||||
|
|
||||||
const locationData = useIPLocationData();
|
const locationData = useIPLocationData() || {
|
||||||
|
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||||
|
};
|
||||||
const authenticatedData = useAuthenticatedData();
|
const authenticatedData = useAuthenticatedData();
|
||||||
const isMobileWidth = useIsMobileWidth();
|
const isMobileWidth = useIsMobileWidth();
|
||||||
|
|
||||||
@@ -231,6 +233,7 @@ export default function SharedChat() {
|
|||||||
region: locationData.region,
|
region: locationData.region,
|
||||||
country: locationData.country,
|
country: locationData.country,
|
||||||
city: locationData.city,
|
city: locationData.city,
|
||||||
|
country_code: locationData.countryCode,
|
||||||
timezone: locationData.timezone,
|
timezone: locationData.timezone,
|
||||||
}),
|
}),
|
||||||
...(image64 && { image: image64 }),
|
...(image64 && { image: image64 }),
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def extract_questions_anthropic(
|
|||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
"""
|
"""
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||||
|
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
@@ -158,8 +158,7 @@ def converse_anthropic(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
location_prompt = prompts.user_location.format(location=location)
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
if user_name:
|
if user_name:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def extract_questions_gemini(
|
|||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
"""
|
"""
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||||
|
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
@@ -163,8 +163,7 @@ def converse_gemini(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
location_prompt = prompts.user_location.format(location=location)
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
if user_name:
|
if user_name:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def extract_questions_offline(
|
|||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
|
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||||
|
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
@@ -171,8 +171,7 @@ def converse_offline(
|
|||||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
location_prompt = prompts.user_location.format(location=location)
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
if user_name:
|
if user_name:
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def extract_questions(
|
|||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
"""
|
"""
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||||
|
|
||||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
@@ -159,8 +159,7 @@ def converse(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
location_prompt = prompts.user_location.format(location=location)
|
|
||||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
if user_name:
|
if user_name:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from collections import defaultdict
|
|||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
|
||||||
@@ -80,7 +79,7 @@ async def search_online(
|
|||||||
|
|
||||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||||
search_tasks = [search_func(subquery) for subquery in subqueries]
|
search_tasks = [search_func(subquery, location) for subquery in subqueries]
|
||||||
search_results = await asyncio.gather(*search_tasks)
|
search_results = await asyncio.gather(*search_tasks)
|
||||||
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
||||||
|
|
||||||
@@ -115,8 +114,9 @@ async def search_online(
|
|||||||
yield response_dict
|
yield response_dict
|
||||||
|
|
||||||
|
|
||||||
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
async def search_with_google(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||||
payload = json.dumps({"q": query})
|
country_code = location.country_code.lower() if location and location.country_code else "us"
|
||||||
|
payload = json.dumps({"q": query, "gl": country_code})
|
||||||
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
|
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@@ -220,7 +220,7 @@ async def read_webpage_with_jina(web_url: str) -> str:
|
|||||||
return response_json["data"]["content"]
|
return response_json["data"]["content"]
|
||||||
|
|
||||||
|
|
||||||
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||||
encoded_query = urllib.parse.quote(query)
|
encoded_query = urllib.parse.quote(query)
|
||||||
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
|
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
|
||||||
headers = {"Accept": "application/json"}
|
headers = {"Accept": "application/json"}
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ from khoj.utils.helpers import (
|
|||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
command_descriptions,
|
command_descriptions,
|
||||||
convert_image_to_webp,
|
convert_image_to_webp,
|
||||||
|
get_country_code_from_timezone,
|
||||||
|
get_country_name_from_timezone,
|
||||||
get_device,
|
get_device,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
)
|
)
|
||||||
@@ -529,6 +531,7 @@ class ChatRequestBody(BaseModel):
|
|||||||
city: Optional[str] = None
|
city: Optional[str] = None
|
||||||
region: Optional[str] = None
|
region: Optional[str] = None
|
||||||
country: Optional[str] = None
|
country: Optional[str] = None
|
||||||
|
country_code: Optional[str] = None
|
||||||
timezone: Optional[str] = None
|
timezone: Optional[str] = None
|
||||||
image: Optional[str] = None
|
image: Optional[str] = None
|
||||||
create_new: Optional[bool] = False
|
create_new: Optional[bool] = False
|
||||||
@@ -556,7 +559,8 @@ async def chat(
|
|||||||
conversation_id = body.conversation_id
|
conversation_id = body.conversation_id
|
||||||
city = body.city
|
city = body.city
|
||||||
region = body.region
|
region = body.region
|
||||||
country = body.country
|
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||||
|
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||||
timezone = body.timezone
|
timezone = body.timezone
|
||||||
image = body.image
|
image = body.image
|
||||||
|
|
||||||
@@ -658,8 +662,8 @@ async def chat(
|
|||||||
|
|
||||||
user_name = await aget_user_name(user)
|
user_name = await aget_user_name(user)
|
||||||
location = None
|
location = None
|
||||||
if city or region or country:
|
if city or region or country or country_code:
|
||||||
location = LocationData(city=city, region=region, country=country)
|
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
async for result in send_llm_response("Please ask your query to get started."):
|
async for result in send_llm_response("Please ask your query to get started."):
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ async def infer_webpage_urls(
|
|||||||
"""
|
"""
|
||||||
Infer webpage links from the given query
|
Infer webpage links from the given query
|
||||||
"""
|
"""
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
@@ -405,7 +405,7 @@ async def generate_online_subqueries(
|
|||||||
"""
|
"""
|
||||||
Generate subqueries from the given query
|
Generate subqueries from the given query
|
||||||
"""
|
"""
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
@@ -535,8 +535,7 @@ async def generate_better_image_prompt(
|
|||||||
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
|
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
location_prompt = prompts.user_location.format(location=location)
|
|
||||||
else:
|
else:
|
||||||
location_prompt = "Unknown"
|
location_prompt = "Unknown"
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import random
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import lru_cache
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
@@ -24,6 +25,7 @@ import torch
|
|||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pytz import country_names, country_timezones
|
||||||
|
|
||||||
from khoj.utils import constants
|
from khoj.utils import constants
|
||||||
|
|
||||||
@@ -431,3 +433,24 @@ def convert_image_to_webp(image_bytes):
|
|||||||
webp_image_bytes = webp_image_io.getvalue()
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
webp_image_io.close()
|
webp_image_io.close()
|
||||||
return webp_image_bytes
|
return webp_image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def tz_to_cc_map() -> dict[str, str]:
|
||||||
|
"""Create a mapping of timezone to country code"""
|
||||||
|
timezone_country = {}
|
||||||
|
for countrycode in country_timezones:
|
||||||
|
timezones = country_timezones[countrycode]
|
||||||
|
for timezone in timezones:
|
||||||
|
timezone_country[timezone] = countrycode
|
||||||
|
return timezone_country
|
||||||
|
|
||||||
|
|
||||||
|
def get_country_code_from_timezone(tz: str) -> str:
|
||||||
|
"""Get country code from timezone"""
|
||||||
|
return tz_to_cc_map().get(tz, "US")
|
||||||
|
|
||||||
|
|
||||||
|
def get_country_name_from_timezone(tz: str) -> str:
|
||||||
|
"""Get country name from timezone"""
|
||||||
|
return country_names.get(get_country_code_from_timezone(tz), "United States")
|
||||||
|
|||||||
@@ -25,6 +25,17 @@ class LocationData(BaseModel):
|
|||||||
city: Optional[str]
|
city: Optional[str]
|
||||||
region: Optional[str]
|
region: Optional[str]
|
||||||
country: Optional[str]
|
country: Optional[str]
|
||||||
|
country_code: Optional[str]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
parts = []
|
||||||
|
if self.city:
|
||||||
|
parts.append(self.city)
|
||||||
|
if self.region:
|
||||||
|
parts.append(self.region)
|
||||||
|
if self.country:
|
||||||
|
parts.append(self.country)
|
||||||
|
return ", ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class FileFilterRequest(BaseModel):
|
class FileFilterRequest(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user