diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4ce50cf2..cb16a4ab 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -55,6 +55,8 @@ from khoj.utils.helpers import ( ConversationCommand, command_descriptions, convert_image_to_webp, + get_country_code_from_timezone, + get_country_name_from_timezone, get_device, is_none_or_empty, ) @@ -556,8 +558,8 @@ async def chat( conversation_id = body.conversation_id city = body.city region = body.region - country = body.country - country_code = body.country_code + country = body.country or get_country_name_from_timezone(body.timezone) + country_code = body.country_code or get_country_code_from_timezone(body.timezone) timezone = body.timezone image = body.image diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 1ebb1fdd..77395cdc 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -9,6 +9,7 @@ import random import uuid from collections import OrderedDict from enum import Enum +from functools import lru_cache from importlib import import_module from importlib.metadata import version from itertools import islice @@ -24,6 +25,7 @@ import torch from asgiref.sync import sync_to_async from magika import Magika from PIL import Image +from pytz import country_names, country_timezones from khoj.utils import constants @@ -431,3 +433,24 @@ def convert_image_to_webp(image_bytes): webp_image_bytes = webp_image_io.getvalue() webp_image_io.close() return webp_image_bytes + + +@lru_cache +def tz_to_cc_map() -> dict[str, str]: + """Create a mapping of timezone to country code""" + timezone_country = {} + for countrycode in country_timezones: + timezones = country_timezones[countrycode] + for timezone in timezones: + timezone_country[timezone] = countrycode + return timezone_country + + +def get_country_code_from_timezone(tz: str) -> str: + """Get country code from timezone""" + return tz_to_cc_map().get(tz, "US") + + +def get_country_name_from_timezone(tz: str) -> str: + """Get country name from timezone""" + return country_names.get(get_country_code_from_timezone(tz), "United States")