diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 009709fd..2f074b2e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -14,11 +14,9 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import PIL.Image import pyjson5 import requests -import tiktoken import yaml from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, ConfigDict, ValidationError -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.database.adapters import ConversationAdapters from khoj.database.models import ( @@ -32,9 +30,10 @@ from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter -from khoj.utils import state from khoj.utils.helpers import ( ConversationCommand, + count_tokens, + get_encoder, is_none_or_empty, is_promptrace_enabled, merge_dicts, @@ -724,72 +723,6 @@ def generate_chatml_messages_with_context( return messages -def get_encoder( - model_name: str, - tokenizer_name=None, -) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast: - default_tokenizer = "gpt-4o" - - try: - if tokenizer_name: - if tokenizer_name in state.pretrained_tokenizers: - encoder = state.pretrained_tokenizers[tokenizer_name] - else: - encoder = AutoTokenizer.from_pretrained(tokenizer_name) - state.pretrained_tokenizers[tokenizer_name] = encoder - else: - # as tiktoken doesn't recognize o1 model series yet - encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) - except Exception: - encoder = tiktoken.encoding_for_model(default_tokenizer) - if state.verbose > 2: - logger.debug( - f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing." - ) - return encoder - - -def count_tokens( - message_content: str | list[str | dict], - encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding, -) -> int: - """ - Count the total number of tokens in a list of messages. - - Assumes each images takes 500 tokens for approximation. - """ - if isinstance(message_content, list): - image_count = 0 - message_content_parts: list[str] = [] - # Collate message content into single string to ease token counting - for part in message_content: - if isinstance(part, dict) and part.get("type") == "image_url": - image_count += 1 - elif isinstance(part, dict) and part.get("type") == "text": - message_content_parts.append(part["text"]) - elif isinstance(part, dict) and hasattr(part, "model_dump"): - message_content_parts.append(json.dumps(part.model_dump())) - elif isinstance(part, dict) and hasattr(part, "__dict__"): - message_content_parts.append(json.dumps(part.__dict__)) - elif isinstance(part, dict): - # If part is a dict but not a recognized type, convert to JSON string - try: - message_content_parts.append(json.dumps(part)) - except (TypeError, ValueError) as e: - logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.") - image_count += 1 # Treat as an image/binary if serialization fails - elif isinstance(part, str): - message_content_parts.append(part) - else: - logger.warning(f"Unknown message type: {part}. Skipping.") - message_content = "\n".join(message_content_parts).rstrip() - return len(encoder.encode(message_content)) + image_count * 500 - elif isinstance(message_content, str): - return len(encoder.encode(message_content)) - else: - return len(encoder.encode(json.dumps(message_content))) - - def count_total_tokens( messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None ) -> Tuple[int, int]: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 4de466fd..1afa9675 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -5,6 +5,7 @@ import copy import datetime import io import ipaddress +import json import logging import os import platform @@ -31,6 +32,7 @@ import openai import psutil import pyjson5 import requests +import tiktoken import torch from asgiref.sync import sync_to_async from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email @@ -41,6 +43,7 @@ from magika import Magika from PIL import Image from pydantic import BaseModel from pytz import country_names, country_timezones +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from khoj.utils import constants @@ -976,6 +979,64 @@ def get_cost_of_chat_message( return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost +def get_encoder( + model_name: str, + tokenizer_name=None, +) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast: + default_tokenizer = "gpt-4o" + + try: + if tokenizer_name: + encoder = AutoTokenizer.from_pretrained(tokenizer_name) + else: + # as tiktoken doesn't recognize o1 model series yet + encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) + except Exception: + encoder = tiktoken.encoding_for_model(default_tokenizer) + return encoder + + +def count_tokens( + message_content: str | list[str | dict], + encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding, +) -> int: + """ + Count the total number of tokens in a list of messages. + + Assumes each images takes 500 tokens for approximation. + """ + if isinstance(message_content, list): + image_count = 0 + message_content_parts: list[str] = [] + # Collate message content into single string to ease token counting + for part in message_content: + if isinstance(part, dict) and part.get("type") == "image_url": + image_count += 1 + elif isinstance(part, dict) and part.get("type") == "text": + message_content_parts.append(part["text"]) + elif isinstance(part, dict) and hasattr(part, "model_dump"): + message_content_parts.append(json.dumps(part.model_dump())) + elif isinstance(part, dict) and hasattr(part, "__dict__"): + message_content_parts.append(json.dumps(part.__dict__)) + elif isinstance(part, dict): + # If part is a dict but not a recognized type, convert to JSON string + try: + message_content_parts.append(json.dumps(part)) + except (TypeError, ValueError) as e: + logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.") + image_count += 1 # Treat as an image/binary if serialization fails + elif isinstance(part, str): + message_content_parts.append(part) + else: + logger.warning(f"Unknown message type: {part}. Skipping.") + message_content = "\n".join(message_content_parts).rstrip() + return len(encoder.encode(message_content)) + image_count * 500 + elif isinstance(message_content, str): + return len(encoder.encode(message_content)) + else: + return len(encoder.encode(json.dumps(message_content))) + + def get_chat_usage_metrics( model_name: str, input_tokens: int = 0, diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 3958173d..f0005093 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -6,7 +6,6 @@ from typing import Dict, List from apscheduler.schedulers.background import BackgroundScheduler from openai import OpenAI -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from whisper import Whisper from khoj.database.models import ProcessLock @@ -35,7 +34,6 @@ telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE") khoj_version: str = None device = get_device() anonymous_mode: bool = False -pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict() billing_enabled: bool = ( os.getenv("STRIPE_API_KEY") is not None and os.getenv("STRIPE_SIGNING_SECRET") is not None