Make Online Search Location Aware (#929)

## Overview
Add user country code as context for doing online search with serper.dev API.
This should find more user relevant results from online searches by Khoj

## Details
### Major
- Default to using system clock to infer user timezone on js clients
- Infer country from timezone when only timezone received by chat API
- Localize online search results to user country when location available

### Minor
- Add `__str__` func to `LocationData` class to deduplicate location string generation
This commit is contained in:
Debanjum
2024-10-03 12:33:47 -07:00
committed by GitHub
16 changed files with 95 additions and 53 deletions

View File

@@ -60,7 +60,8 @@
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let countryCode = null;
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
let chatMessageState = { let chatMessageState = {
newResponseTextEl: null, newResponseTextEl: null,
newResponseEl: null, newResponseEl: null,
@@ -76,6 +77,7 @@
region = data.region; region = data.region;
city = data.city; city = data.city;
countryName = data.country_name; countryName = data.country_name;
countryCode = data.country_code;
timezone = data.timezone; timezone = data.timezone;
}) })
.catch(err => { .catch(err => {
@@ -157,6 +159,7 @@
...(!!city && { city: city }), ...(!!city && { city: city }),
...(!!region && { region: region }), ...(!!region && { region: region }),
...(!!countryName && { country: countryName }), ...(!!countryName && { country: countryName }),
...(!!countryCode && { country_code: countryCode }),
...(!!timezone && { timezone: timezone }), ...(!!timezone && { timezone: timezone }),
}; };

View File

@@ -308,18 +308,19 @@
<script src="./utils.js"></script> <script src="./utils.js"></script>
<script src="./chatutils.js"></script> <script src="./chatutils.js"></script>
<script> <script>
let region = null; let region = null;
let city = null; let city = null;
let countryName = null; let countryName = null;
let timezone = null; let countryCode = null;
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
fetch("https://ipapi.co/json") fetch("https://ipapi.co/json")
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
region = data.region;
city = data.city; city = data.city;
region = data.region;
countryName = data.country_name; countryName = data.country_name;
countryCode = data.country_code;
timezone = data.timezone; timezone = data.timezone;
}) })
.catch(err => { .catch(err => {
@@ -410,6 +411,7 @@
...(!!city && { city: city }), ...(!!city && { city: city }),
...(!!region && { region: region }), ...(!!region && { region: region }),
...(!!countryName && { country: countryName }), ...(!!countryName && { country: countryName }),
...(!!countryCode && { country_code: countryCode }),
...(!!timezone && { timezone: timezone }), ...(!!timezone && { timezone: timezone }),
}; };

View File

@@ -33,9 +33,10 @@ interface ChatMessageState {
} }
interface Location { interface Location {
region: string; region?: string;
city: string; city?: string;
countryName: string; countryName?: string;
countryCode?: string;
timezone: string; timezone: string;
} }
@@ -43,7 +44,7 @@ export class KhojChatView extends KhojPaneView {
result: string; result: string;
setting: KhojSetting; setting: KhojSetting;
waitingForLocation: boolean; waitingForLocation: boolean;
location: Location; location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
keyPressTimeout: NodeJS.Timeout | null = null; keyPressTimeout: NodeJS.Timeout | null = null;
userMessages: string[] = []; // Store user sent messages for input history cycling userMessages: string[] = []; // Store user sent messages for input history cycling
currentMessageIndex: number = -1; // Track current message index in userMessages array currentMessageIndex: number = -1; // Track current message index in userMessages array
@@ -70,6 +71,7 @@ export class KhojChatView extends KhojPaneView {
region: data.region, region: data.region,
city: data.city, city: data.city,
countryName: data.country_name, countryName: data.country_name,
countryCode: data.country_code,
timezone: data.timezone, timezone: data.timezone,
}; };
}) })
@@ -1056,12 +1058,11 @@ export class KhojChatView extends KhojPaneView {
n: this.setting.resultsCount, n: this.setting.resultsCount,
stream: true, stream: true,
...(!!conversationId && { conversation_id: conversationId }), ...(!!conversationId && { conversation_id: conversationId }),
...(!!this.location && { ...(!!this.location && this.location.city && { city: this.location.city }),
city: this.location.city, ...(!!this.location && this.location.region && { region: this.location.region }),
region: this.location.region, ...(!!this.location && this.location.countryName && { country: this.location.countryName }),
country: this.location.countryName, ...(!!this.location && this.location.countryCode && { country_code: this.location.countryCode }),
timezone: this.location.timezone, ...(!!this.location && this.location.timezone && { timezone: this.location.timezone }),
}),
}; };
let newResponseEl = this.createKhojResponseDiv(); let newResponseEl = this.createKhojResponseDiv();

View File

@@ -518,12 +518,14 @@ function EditCard(props: EditCardProps) {
updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`; updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`;
} }
updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`; updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`;
if (props.locationData) { if (props.locationData && props.locationData.city)
updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`; updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`;
if (props.locationData && props.locationData.region)
updateQueryUrl += `&region=${encodeURIComponent(props.locationData.region)}`; updateQueryUrl += `&region=${encodeURIComponent(props.locationData.region)}`;
if (props.locationData && props.locationData.country)
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`; updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
if (props.locationData && props.locationData.timezone)
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`; updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
}
let method = props.createNew ? "POST" : "PUT"; let method = props.createNew ? "POST" : "PUT";

View File

@@ -136,7 +136,9 @@ export default function Chat() {
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]); const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [image64, setImage64] = useState<string>(""); const [image64, setImage64] = useState<string>("");
const locationData = useIPLocationData(); const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
};
const authenticatedData = useAuthenticatedData(); const authenticatedData = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth(); const isMobileWidth = useIsMobileWidth();
@@ -241,9 +243,10 @@ export default function Chat() {
conversation_id: conversationId, conversation_id: conversationId,
stream: true, stream: true,
...(locationData && { ...(locationData && {
city: locationData.city,
region: locationData.region, region: locationData.region,
country: locationData.country, country: locationData.country,
city: locationData.city, country_code: locationData.countryCode,
timezone: locationData.timezone, timezone: locationData.timezone,
}), }),
...(image64 && { image: image64 }), ...(image64 && { image: image64 }),

View File

@@ -2,13 +2,10 @@ import { useEffect, useState } from "react";
import useSWR from "swr"; import useSWR from "swr";
export interface LocationData { export interface LocationData {
ip: string; city?: string;
city: string; region?: string;
region: string; country?: string;
country: string; countryCode?: string;
postal: string;
latitude: number;
longitude: number;
timezone: string; timezone: string;
} }
@@ -50,9 +47,7 @@ export function useIPLocationData() {
{ revalidateOnFocus: false }, { revalidateOnFocus: false },
); );
if (locationDataError) return null; if (locationDataError || !locationData) return;
if (!locationData) return null;
return locationData; return locationData;
} }

View File

@@ -111,7 +111,9 @@ export default function SharedChat() {
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined); const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
const [image64, setImage64] = useState<string>(""); const [image64, setImage64] = useState<string>("");
const locationData = useIPLocationData(); const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
};
const authenticatedData = useAuthenticatedData(); const authenticatedData = useAuthenticatedData();
const isMobileWidth = useIsMobileWidth(); const isMobileWidth = useIsMobileWidth();
@@ -231,6 +233,7 @@ export default function SharedChat() {
region: locationData.region, region: locationData.region,
country: locationData.country, country: locationData.country,
city: locationData.city, city: locationData.city,
country_code: locationData.countryCode,
timezone: locationData.timezone, timezone: locationData.timezone,
}), }),
...(image64 && { image: image64 }), ...(image64 && { image: image64 }),

View File

@@ -32,7 +32,7 @@ def extract_questions_anthropic(
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -158,8 +158,7 @@ def converse_anthropic(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -33,7 +33,7 @@ def extract_questions_gemini(
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -163,8 +163,7 @@ def converse_gemini(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -46,7 +46,7 @@ def extract_questions_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -171,8 +171,7 @@ def converse_offline(
conversation_primer = prompts.query_prompt.format(query=user_query) conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -36,7 +36,7 @@ def extract_questions(
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
@@ -159,8 +159,7 @@ def converse(
) )
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}" system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name: if user_name:

View File

@@ -7,7 +7,6 @@ from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import aiohttp import aiohttp
import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
@@ -80,7 +79,7 @@ async def search_online(
with timer(f"Internet searches for {list(subqueries)} took", logger): with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery) for subquery in subqueries] search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks) search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results} response_dict = {subquery: search_result for subquery, search_result in search_results}
@@ -115,8 +114,9 @@ async def search_online(
yield response_dict yield response_dict
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def search_with_google(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
payload = json.dumps({"q": query}) country_code = location.country_code.lower() if location and location.country_code else "us"
payload = json.dumps({"q": query, "gl": country_code})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"} headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -220,7 +220,7 @@ async def read_webpage_with_jina(web_url: str) -> str:
return response_json["data"]["content"] return response_json["data"]["content"]
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
encoded_query = urllib.parse.quote(query) encoded_query = urllib.parse.quote(query)
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}" jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}

View File

@@ -55,6 +55,8 @@ from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
command_descriptions, command_descriptions,
convert_image_to_webp, convert_image_to_webp,
get_country_code_from_timezone,
get_country_name_from_timezone,
get_device, get_device,
is_none_or_empty, is_none_or_empty,
) )
@@ -529,6 +531,7 @@ class ChatRequestBody(BaseModel):
city: Optional[str] = None city: Optional[str] = None
region: Optional[str] = None region: Optional[str] = None
country: Optional[str] = None country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None timezone: Optional[str] = None
image: Optional[str] = None image: Optional[str] = None
create_new: Optional[bool] = False create_new: Optional[bool] = False
@@ -556,7 +559,8 @@ async def chat(
conversation_id = body.conversation_id conversation_id = body.conversation_id
city = body.city city = body.city
region = body.region region = body.region
country = body.country country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone timezone = body.timezone
image = body.image image = body.image
@@ -658,8 +662,8 @@ async def chat(
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
location = None location = None
if city or region or country: if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country) location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q): if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."): async for result in send_llm_response("Please ask your query to get started."):

View File

@@ -369,7 +369,7 @@ async def infer_webpage_urls(
""" """
Infer webpage links from the given query Infer webpage links from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
@@ -405,7 +405,7 @@ async def generate_online_subqueries(
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
@@ -535,8 +535,7 @@ async def generate_better_image_prompt(
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data: if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=f"{location_data}")
location_prompt = prompts.user_location.format(location=location)
else: else:
location_prompt = "Unknown" location_prompt = "Unknown"

View File

@@ -9,6 +9,7 @@ import random
import uuid import uuid
from collections import OrderedDict from collections import OrderedDict
from enum import Enum from enum import Enum
from functools import lru_cache
from importlib import import_module from importlib import import_module
from importlib.metadata import version from importlib.metadata import version
from itertools import islice from itertools import islice
@@ -24,6 +25,7 @@ import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from magika import Magika from magika import Magika
from PIL import Image from PIL import Image
from pytz import country_names, country_timezones
from khoj.utils import constants from khoj.utils import constants
@@ -431,3 +433,24 @@ def convert_image_to_webp(image_bytes):
webp_image_bytes = webp_image_io.getvalue() webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close() webp_image_io.close()
return webp_image_bytes return webp_image_bytes
@lru_cache
def tz_to_cc_map() -> dict[str, str]:
"""Create a mapping of timezone to country code"""
timezone_country = {}
for countrycode in country_timezones:
timezones = country_timezones[countrycode]
for timezone in timezones:
timezone_country[timezone] = countrycode
return timezone_country
def get_country_code_from_timezone(tz: str) -> str:
"""Get country code from timezone"""
return tz_to_cc_map().get(tz, "US")
def get_country_name_from_timezone(tz: str) -> str:
"""Get country name from timezone"""
return country_names.get(get_country_code_from_timezone(tz), "United States")

View File

@@ -25,6 +25,17 @@ class LocationData(BaseModel):
city: Optional[str] city: Optional[str]
region: Optional[str] region: Optional[str]
country: Optional[str] country: Optional[str]
country_code: Optional[str]
def __str__(self):
parts = []
if self.city:
parts.append(self.city)
if self.region:
parts.append(self.region)
if self.country:
parts.append(self.country)
return ", ".join(parts)
class FileFilterRequest(BaseModel): class FileFilterRequest(BaseModel):