Files
khoj/src/khoj/utils/helpers.py
2024-10-09 23:38:20 -07:00

480 lines
17 KiB
Python

from __future__ import annotations # to avoid quoting type hints
import datetime
import io
import logging
import os
import platform
import random
import uuid
from collections import OrderedDict
from enum import Enum
from functools import lru_cache
from importlib import import_module
from importlib.metadata import version
from itertools import islice
from os import path
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse
import psutil
import requests
import torch
from asgiref.sync import sync_to_async
from magika import Magika
from PIL import Image
from pytz import country_names, country_timezones
from khoj.utils import constants
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import AppConfig
# Initialize Magika once at import time for file type identification.
# This shared instance is what get_file_type() uses to inspect file content.
magika = Magika()
class AsyncIteratorWrapper:
    """Expose a synchronous iterable through the async iterator protocol.

    Every item is pulled from the wrapped iterator via `sync_to_async`.
    Once the wrapped iterator is exhausted, `__anext__` produces `None`
    instead of ending the iteration, so consumers have to treat a `None`
    item as the end-of-stream marker and stop iterating themselves.
    """

    def __init__(self, obj):
        self._it = iter(obj)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return await self.next_async()
        except StopAsyncIteration:
            # Exhaustion is reported as a None item, not by ending the async iteration
            return None

    @sync_to_async
    def next_async(self):
        """Fetch the next item, signalling exhaustion with StopAsyncIteration."""
        try:
            return next(self._it)
        except StopIteration:
            raise StopAsyncIteration
def is_none_or_empty(item):
    """Check whether `item` is None, an empty sized iterable or an empty string.

    Iterables that cannot report their length (e.g. generators) are never
    considered empty, as that cannot be known without consuming them.
    """
    if item is None:
        return True

    # Only ask for the length of iterables that actually support len()
    is_sized_iterable = hasattr(item, "__iter__") and hasattr(item, "__len__")
    if is_sized_iterable and len(item) == 0:
        return True

    return item == ""
def to_snake_case_from_dash(item: str):
    """Return `item` with every underscore replaced by a dash.

    NOTE(review): the function name suggests a dash -> underscore conversion,
    but the code converts underscores to dashes. Behavior is kept as is;
    confirm against callers before renaming or inverting it.
    """
    return "-".join(item.split("_"))
def get_absolute_path(filepath: Union[str, Path]) -> str:
    """Return `filepath` as an absolute path string, with `~` expanded.

    Symlinks and `..` components are left untouched.
    """
    expanded_path = Path(filepath).expanduser()
    return str(expanded_path.absolute())
def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> Path:
    """Return `filepath` as an absolute, fully resolved Path, with `~` expanded.

    `strict` is forwarded to `Path.resolve`.
    """
    absolute_path = Path(filepath).expanduser().absolute()
    return absolute_path.resolve(strict=strict)
def get_from_dict(dictionary, *args):
    """Null-aware get from a nested dictionary.

    Returns: dictionary[args[0]][args[1]]... or None if any keys missing,
    or if a value on the way cannot be looked into with the given key
    (e.g. a string or number found where a nested dictionary was expected).
    """
    current = dictionary
    for arg in args:
        if not hasattr(current, "__iter__"):
            return None
        try:
            if arg not in current:
                return None
            current = current[arg]
        except (KeyError, IndexError, TypeError):
            # Membership test or lookup is not supported for this key on this value
            return None
    return current
def merge_dicts(priority_dict: dict, default_dict: dict):
    """Merge two dictionaries, preferring values from `priority_dict`.

    Keys only present in `default_dict` are added to the result. When both
    dictionaries hold a dictionary under the same key, those are merged
    recursively. The result is a new dictionary built from a shallow copy
    of `priority_dict`.
    """
    merged = priority_dict.copy()
    for key, default_value in default_dict.items():
        if key not in priority_dict:
            merged[key] = default_value
            continue

        priority_value = priority_dict[key]
        if isinstance(priority_value, dict) and isinstance(default_value, dict):
            merged[key] = merge_dicts(priority_value, default_value)
    return merged
# Known mime types mapped to the file type label reported by get_file_type
_MIME_TYPE_TO_FILE_TYPE = {
    "text/markdown": "markdown",
    "text/org": "org",
    "application/pdf": "pdf",
    "application/msword": "docx",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
    "image/jpeg": "image",
    "image/png": "image",
}


def get_file_type(file_type: str, file_content: bytes) -> tuple[str, Optional[str]]:
    """Get file type from file mime type, falling back to inspecting file content.

    Args:
        file_type: Mime type, optionally followed by parameters,
            e.g. "text/markdown; charset=utf-8".
        file_content: Raw file content. Only inspected when the mime type is
            not one of the known types.

    Returns:
        Tuple of (file type label, encoding). The label is one of "markdown",
        "org", "pdf", "docx", "image", "plaintext" or "other". The encoding is
        the lowercased value of the first mime type parameter, or None when
        the mime type carries no `key=value` parameter.
    """
    # Split mime type from its parameters, e.g. "text/org; charset=utf-8"
    mime_type, _, parameters = file_type.partition(";")
    mime_type = mime_type.strip()

    # Extract encoding from the first parameter, stopping at the next parameter if any
    encoding = None
    if "=" in parameters:
        encoding = parameters.split("=", 1)[1].split(";", 1)[0].strip().lower()

    # Known mime types do not need their content inspected
    known_file_type = _MIME_TYPE_TO_FILE_TYPE.get(mime_type)
    if known_file_type is not None:
        return known_file_type, encoding

    # Infer content type from reading file content
    try:
        content_group = magika.identify_bytes(file_content).output.group
    except Exception:
        # Fallback to using just file type if content type cannot be inferred
        content_group = "unknown"

    if content_group in ("code", "text"):
        return "plaintext", encoding
    return "other", encoding
def load_model(
    model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
    """Load model from disk or huggingface.

    Args:
        model_name: Name of the model. Any "/" in it is replaced with "_" to
            derive the model's directory name under `model_dir`.
        model_type: Model class, or dotted path to the model class, to instantiate.
        model_dir: Directory to load the model from and save it to. When None,
            the model is always loaded by name and never saved.
        device: Passed as the `device` keyword argument to the model class.

    Returns:
        The instantiated model.
    """
    logger = logging.getLogger(__name__)

    # Construct model path
    model_path = path.join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
    model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type

    # Load model from model_path if it exists there
    if model_path is not None and resolve_absolute_path(model_path).exists():
        logger.debug(f"Loading {model_name} model from disk")
        return model_type_class(get_absolute_path(model_path), device=device)

    # Else load the model from the model_name
    logger.info(f"🤖 Downloading {model_name} model from web")
    model = model_type_class(model_name, device=device)
    if model_path is not None:
        # Log only once the save has gone through, so the message never claims a save that failed
        model.save(model_path)
        logger.info(f"📩 Saved {model_name} model to disk")
    return model
def get_class_by_name(name: str) -> object:
    """Return the object that the dotted path `name` points to.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.
    """
    name_parts = name.rsplit(".", 1)
    module_name, class_name = name_parts
    module = import_module(module_name)
    return getattr(module, class_name)
class timer:
    """Context manager that logs, at debug level, how long its block took to run.

    Args:
        message: Text the logged duration is prefixed with.
        logger: Logger the duration is written to.
        device: Optional device to mention in the log line.
    """

    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
        self.message = message
        self.logger = logger
        self.device = device

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, *_):
        elapsed = perf_counter() - self.start
        # Mention the device only when one was given
        device_suffix = "" if self.device is None else f" on device: {self.device}"
        self.logger.debug(f"{self.message}: {elapsed:.3f} seconds{device_suffix}")
class LRU(OrderedDict):
    """Size-bounded ordered dictionary that evicts its least recently used key.

    Reading a key with `cache[key]` or writing it with `cache[key] = value`
    marks that key as most recently used. When a write grows the dictionary
    beyond `capacity`, the least recently used key is dropped.
    Only `__getitem__` and `__setitem__` are overridden here.
    """

    def __init__(self, *args, capacity=128, **kwargs):
        # Set capacity before populating, as __setitem__ reads it
        self.capacity = capacity
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        # Overwriting an existing key counts as a use, so refresh its position first
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)
        if len(self) > self.capacity:
            oldest = next(iter(self))
            del self[oldest]
def get_server_id():
    """Get, Generate Persistent, Random ID per server install.

    Helps count distinct khoj servers deployed.
    Maintains anonymity by using non-PII random id.

    The id is read from the `server_id=<value>` line of the khoj env file.
    If the file or a usable line is missing, a new random id is generated,
    appended to the file and returned.
    """
    # Expand path to the khoj env file. It contains persistent internal app data
    app_env_filename = path.expanduser(constants.app_env_filepath)

    server_id = None
    ends_with_newline = True
    if path.exists(app_env_filename):
        with open(app_env_filename, "r") as f:
            for line in f:
                ends_with_newline = line.endswith("\n")
                # Tolerate blank lines and lines without a single key=value shape
                key, separator, value = line.partition("=")
                if separator and key.strip() == "server_id" and value.strip():
                    server_id = value.strip()
                    break

    if server_id is not None:
        return server_id

    # If server_id is not found, generate a new one
    server_id = str(uuid.uuid4())

    # Create khoj config directory if it doesn't exist
    app_env_dir = path.dirname(app_env_filename)
    if app_env_dir:
        os.makedirs(app_env_dir, exist_ok=True)

    # Append the server_id to the env file, keeping any existing content intact
    with open(app_env_filename, "a") as f:
        if not ends_with_newline:
            # Avoid gluing the new entry onto an unterminated last line
            f.write("\n")
        f.write(f"server_id={server_id}\n")

    return server_id
def telemetry_disabled(app_config: AppConfig):
    """Return True when there is no app config or it turns telemetry logging off."""
    if not app_config:
        return True
    return not app_config.should_log_telemetry
def log_telemetry(
    telemetry_type: str,
    api: Optional[str] = None,
    client: Optional[str] = None,
    app_config: Optional[AppConfig] = None,
    properties: Optional[dict] = None,
):
    """Build basic app usage telemetry like client, os, api called.

    This function only assembles the telemetry data. It does not send it.

    Args:
        telemetry_type: Kind of telemetry event.
        api: API endpoint on server called by client.
        client: Client from which the API was called. E.g. Emacs, Obsidian.
        app_config: App config, used to check whether telemetry is enabled.
        properties: Extra fields to add to the telemetry data. When passed
            without a `server_id`, the server id is added to this dictionary.

    Returns:
        The telemetry data as a dictionary, or an empty list when telemetry
        is disabled.
    """
    # Do not log usage telemetry, if telemetry is disabled via app config
    if telemetry_disabled(app_config):
        return []

    # Callers may omit properties altogether
    if properties is None:
        properties = {}
    if properties.get("server_id") is None:
        properties["server_id"] = get_server_id()

    # Populate telemetry data to log
    request_body = {
        "telemetry_type": telemetry_type,
        "server_version": version("khoj"),
        "os": platform.system(),
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    request_body.update(properties)
    if api:
        # API endpoint on server called by client
        request_body["api"] = api
    if client:
        # Client from which the API was called. E.g. Emacs, Obsidian
        request_body["client"] = client

    return request_body
def get_device_memory() -> int:
    """Get memory of the device models run on.

    Returns the total memory of the CUDA device, the memory allocated by the
    MPS driver, or the total system memory, depending on the device in use.

    NOTE(review): this was documented as GB, but the values are returned
    exactly as the libraries report them, without conversion. Per their
    documentation that is bytes - confirm what callers expect.
    """
    device = get_device()
    device_kind = device.type

    if device_kind == "cuda":
        return torch.cuda.get_device_properties(device).total_memory
    if device_kind == "mps":
        return torch.mps.driver_allocated_memory()
    return psutil.virtual_memory().total
def get_device() -> torch.device:
    """Get device to run model on.

    Prefers the first CUDA GPU, then Apple Metal (MPS), then the CPU.
    """
    if torch.cuda.is_available():
        # Use CUDA GPU
        device_name = "cuda:0"
    elif torch.backends.mps.is_available():
        # Use Apple M1 Metal Acceleration
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)
class ConversationCommand(str, Enum):
    """Enumeration of conversation commands, each identified by a lowercase string value.

    Members also subclass `str`, so a member compares equal to its string value.
    """

    Default = "default"
    General = "general"
    Notes = "notes"
    Help = "help"
    Online = "online"
    Webpage = "webpage"
    Code = "code"
    Image = "image"
    Text = "text"
    Automation = "automation"
    AutomatedTask = "automated_task"
    Summarize = "summarize"
# Descriptions of conversation commands, worded for the end user
# (they address the reader as "you", e.g. "your knowledge base").
command_descriptions = {
    ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
    ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
    ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
    ConversationCommand.Online: "Search for information on the internet.",
    ConversationCommand.Webpage: "Get information from webpage suggested by you.",
    ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.",
    ConversationCommand.Image: "Generate images by describing your imagination in words.",
    ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
    ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
    ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
}
# Descriptions of what an agent can do with a command.
# Only a subset of the conversation commands is listed here.
command_descriptions_for_agent = {
    ConversationCommand.General: "Agent can use the agents knowledge base and general knowledge.",
    ConversationCommand.Notes: "Agent can search the users knowledge base for information.",
    ConversationCommand.Online: "Agent can search the internet for information.",
    ConversationCommand.Webpage: "Agent can read suggested web pages for information.",
    ConversationCommand.Summarize: "Agent can read an entire document. Agents knowledge base must be a single document.",
}
# Guidance on when to pick each command as a data source or tool,
# worded as instructions to the chat model.
tool_descriptions_for_llm = {
    ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
    ConversationCommand.General: "To use when you can answer the question without any outside information or personal knowledge",
    ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
    ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
    ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
    ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}
# Guidance on when to pick a command, worded as instructions to the chat model.
# Covers only the notes, online, webpage and code commands.
# NOTE(review): presumably used as tool descriptions for function calling - confirm against callers.
function_calling_description_for_llm = {
    ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
    ConversationCommand.Online: "To search for the latest, up-to-date information from the internet.",
    ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
}
# Guidance on when to pick each response mode (image, automation or text),
# worded as instructions to the chat model.
mode_descriptions_for_llm = {
    ConversationCommand.Image: "Use this if the user is requesting you to generate images based on their description. This does not support generating charts or graphs.",
    ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
    ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
}
# Descriptions of the response modes an agent can answer in.
# The image description states that image generation does not cover charts
# and graphs, in line with the image entry of mode_descriptions_for_llm.
mode_descriptions_for_agent = {
    ConversationCommand.Image: "Agent can generate images in response. It cannot use this to generate charts and graphs.",
    ConversationCommand.Text: "Agent can generate text in response.",
}
class ImageIntentType(Enum):
    """
    Chat message intent by Khoj for image responses.
    Marks the schema used to reference image in chat messages.
    Each member's value is the string identifying that schema.
    """

    # Images as Inline PNG
    TEXT_TO_IMAGE = "text-to-image"
    # Images as URLs
    TEXT_TO_IMAGE2 = "text-to-image2"
    # Images as Inline WebP
    TEXT_TO_IMAGE_V3 = "text-to-image-v3"
def generate_random_name():
    """Generate a random two word name of the form "<adjective> <noun>"."""
    # Words to pick the name from
    adjectives = [
        "happy",
        "serendipitous",
        "exuberant",
        "calm",
        "brave",
        "scared",
        "energetic",
        "chivalrous",
        "kind",
        "suave",
    ]
    nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"]

    # Pick the adjective first, then the noun, and join them with a space
    chosen_words = (random.choice(adjectives), random.choice(nouns))
    return " ".join(chosen_words)
def batcher(iterable, max_n):
    """Split an iterable into chunks of at most max_n items.

    Each chunk is yielded as a generator that skips the None items in it.
    A chunk holding only None items is still yielded, as an empty generator.
    """
    source = iter(iterable)
    while chunk := list(islice(source, max_n)):
        yield (item for item in chunk if item is not None)
def is_env_var_true(env_var: str, default: str = "false") -> bool:
    """Check whether a boolean environment variable is set to "true".

    The comparison ignores case. `default` is used when the variable is unset.
    """
    raw_value = os.environ.get(env_var, default)
    return raw_value.lower() == "true"
def in_debug_mode():
    """Check if Khoj is running in debug mode.

    Debug mode is enabled by setting the KHOJ_DEBUG environment variable to true.
    """
    debug_enabled = is_env_var_true("KHOJ_DEBUG")
    return debug_enabled
def is_valid_url(url: str) -> bool:
    """Check if a string is a valid URL.

    A URL counts as valid when it has both a scheme and a network location.
    Values that cannot be parsed as a URL, including non-strings, give False.
    """
    try:
        result = urlparse(url.strip())
    except (AttributeError, TypeError, ValueError):
        # Not a string-like value, or rejected by the URL parser.
        # Anything else, like KeyboardInterrupt, is deliberately not swallowed.
        return False
    return bool(result.scheme and result.netloc)
def is_internet_connected(timeout: float = 5.0):
    """Check for internet access by sending a HEAD request to google.com.

    Args:
        timeout: Seconds to wait for the server before giving up, so an
            unresponsive network cannot block the caller indefinitely.

    Returns:
        True if the request got a 200 response. False if it got another
        status, timed out or failed.
    """
    try:
        response = requests.head("https://www.google.com", timeout=timeout)
    except requests.RequestException:
        # Connection errors and timeouts both mean we are not usefully connected
        return False
    return response.status_code == 200
def convert_image_to_webp(image_bytes):
    """Convert image bytes to webp format for faster loading.

    Returns the raw bytes of the image re-encoded as WebP.
    """
    with Image.open(io.BytesIO(image_bytes)) as source_image, io.BytesIO() as webp_buffer:
        source_image.save(webp_buffer, "WEBP")
        # Read the encoded image out before the buffer gets closed
        return webp_buffer.getvalue()
@lru_cache
def tz_to_cc_map() -> dict[str, str]:
    """Create a mapping of timezone to country code.

    The mapping is built once and cached. A timezone listed under several
    countries maps to the last country it is listed under.
    """
    return {
        timezone: countrycode
        for countrycode in country_timezones
        for timezone in country_timezones[countrycode]
    }
def get_country_code_from_timezone(tz: str) -> str:
    """Get country code from timezone. Defaults to "US" for unknown timezones."""
    timezone_to_country_code = tz_to_cc_map()
    return timezone_to_country_code.get(tz, "US")
def get_country_name_from_timezone(tz: str) -> str:
    """Get country name from timezone.

    Defaults to "United States" when the timezone's country code has no known name.
    """
    country_code = get_country_code_from_timezone(tz)
    return country_names.get(country_code, "United States")