Refactor count_tokens, get_encoder functions to utils/helpers.py

Simplify get_encoder to not rely on global state. The tokenizer
caching is not necessary for now.
Debanjum
2025-11-13 20:04:47 -08:00
parent 15482c54b5
commit d57c597245
3 changed files with 63 additions and 71 deletions
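A minimal sketch of the call pattern after this refactor, assuming the import shown in the first file's diff (the model name is illustrative):

from khoj.utils.helpers import count_tokens, get_encoder

# Resolves through tiktoken for OpenAI-style model names.
encoder = get_encoder("gpt-4o")
n_tokens = count_tokens("Hello, world!", encoder)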

View File

@@ -14,11 +14,9 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import PIL.Image
import pyjson5
import requests
import tiktoken
import yaml
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
@@ -32,9 +30,10 @@ from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.helpers import (
    ConversationCommand,
    count_tokens,
    get_encoder,
    is_none_or_empty,
    is_promptrace_enabled,
    merge_dicts,
@@ -724,72 +723,6 @@ def generate_chatml_messages_with_context(
    return messages


def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    default_tokenizer = "gpt-4o"
    try:
        if tokenizer_name:
            if tokenizer_name in state.pretrained_tokenizers:
                encoder = state.pretrained_tokenizers[tokenizer_name]
            else:
                encoder = AutoTokenizer.from_pretrained(tokenizer_name)
                state.pretrained_tokenizers[tokenizer_name] = encoder
        else:
            # as tiktoken doesn't recognize o1 model series yet
            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
    except Exception:
        encoder = tiktoken.encoding_for_model(default_tokenizer)
        if state.verbose > 2:
            logger.debug(
                f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
            )
    return encoder
def count_tokens(
    message_content: str | list[str | dict],
    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
    """
    Count the total number of tokens in a message's content.
    Assumes each image takes 500 tokens as an approximation.
    """
    if isinstance(message_content, list):
        image_count = 0
        message_content_parts: list[str] = []
        # Collate message content into a single string to ease token counting
        for part in message_content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                image_count += 1
            elif isinstance(part, dict) and part.get("type") == "text":
                message_content_parts.append(part["text"])
            elif isinstance(part, dict) and hasattr(part, "model_dump"):
                message_content_parts.append(json.dumps(part.model_dump()))
            elif isinstance(part, dict) and hasattr(part, "__dict__"):
                message_content_parts.append(json.dumps(part.__dict__))
            elif isinstance(part, dict):
                # If part is a dict but not a recognized type, convert to JSON string
                try:
                    message_content_parts.append(json.dumps(part))
                except (TypeError, ValueError) as e:
                    logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
                    image_count += 1  # Treat as an image/binary if serialization fails
            elif isinstance(part, str):
                message_content_parts.append(part)
            else:
                logger.warning(f"Unknown message type: {part}. Skipping.")
        message_content = "\n".join(message_content_parts).rstrip()
        return len(encoder.encode(message_content)) + image_count * 500
    elif isinstance(message_content, str):
        return len(encoder.encode(message_content))
    else:
        return len(encoder.encode(json.dumps(message_content)))


def count_total_tokens(
    messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
) -> Tuple[int, int]:

View File

@@ -5,6 +5,7 @@ import copy
import datetime
import io
import ipaddress
import json
import logging
import os
import platform
@@ -31,6 +32,7 @@ import openai
import psutil
import pyjson5
import requests
import tiktoken
import torch
from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
@@ -41,6 +43,7 @@ from magika import Magika
from PIL import Image
from pydantic import BaseModel
from pytz import country_names, country_timezones
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.utils import constants
@@ -976,6 +979,64 @@ def get_cost_of_chat_message(
    return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost


def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    default_tokenizer = "gpt-4o"
    try:
        if tokenizer_name:
            encoder = AutoTokenizer.from_pretrained(tokenizer_name)
        else:
            # as tiktoken doesn't recognize o1 model series yet
            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
    except Exception:
        encoder = tiktoken.encoding_for_model(default_tokenizer)
    return encoder
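Both construction paths of the relocated get_encoder, as a hedged usage sketch (the Hugging Face repo id below is hypothetical; any valid tokenizer repo works):

# OpenAI-style names resolve through tiktoken's model registry;
# o1-series names fall back to the gpt-4o encoding per the branch above.
openai_encoder = get_encoder("o1-mini")

# Passing tokenizer_name loads a Hugging Face tokenizer instead;
# from_pretrained reuses the local HF cache across calls.
hf_encoder = get_encoder("some-local-model", tokenizer_name="org/model-tokenizer")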
def count_tokens(
    message_content: str | list[str | dict],
    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
    """
    Count the total number of tokens in a message's content.
    Assumes each image takes 500 tokens as an approximation.
    """
    if isinstance(message_content, list):
        image_count = 0
        message_content_parts: list[str] = []
        # Collate message content into a single string to ease token counting
        for part in message_content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                image_count += 1
            elif isinstance(part, dict) and part.get("type") == "text":
                message_content_parts.append(part["text"])
            elif isinstance(part, dict) and hasattr(part, "model_dump"):
                message_content_parts.append(json.dumps(part.model_dump()))
            elif isinstance(part, dict) and hasattr(part, "__dict__"):
                message_content_parts.append(json.dumps(part.__dict__))
            elif isinstance(part, dict):
                # If part is a dict but not a recognized type, convert to JSON string
                try:
                    message_content_parts.append(json.dumps(part))
                except (TypeError, ValueError) as e:
                    logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
                    image_count += 1  # Treat as an image/binary if serialization fails
            elif isinstance(part, str):
                message_content_parts.append(part)
            else:
                logger.warning(f"Unknown message type: {part}. Skipping.")
        message_content = "\n".join(message_content_parts).rstrip()
        return len(encoder.encode(message_content)) + image_count * 500
    elif isinstance(message_content, str):
        return len(encoder.encode(message_content))
    else:
        return len(encoder.encode(json.dumps(message_content)))
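A worked sketch of the image approximation (token counts illustrative; the text figure depends on the encoding in use):

encoder = get_encoder("gpt-4o")
content = [
    {"type": "text", "text": "Describe this photo."},
    {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
]
# Encoded length of "Describe this photo." (a handful of tokens)
# plus a flat 500 for the single image part: roughly 505 total.
total = count_tokens(content, encoder)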
def get_chat_usage_metrics(
    model_name: str,
    input_tokens: int = 0,

View File

@@ -6,7 +6,6 @@ from typing import Dict, List
from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from whisper import Whisper
from khoj.database.models import ProcessLock
@@ -35,7 +34,6 @@ telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
khoj_version: str = None
device = get_device()
anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = (
os.getenv("STRIPE_API_KEY") is not None
and os.getenv("STRIPE_SIGNING_SECRET") is not None
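With the pretrained_tokenizers registry gone from state, get_encoder now reconstructs its tokenizer on each call; AutoTokenizer.from_pretrained still resolves against the on-disk Hugging Face cache, so nothing is re-downloaded. Should per-call construction ever show up in profiles, a caller could memoize locally without reintroducing global state: a hypothetical sketch, not part of this commit.

from functools import lru_cache

from khoj.utils.helpers import get_encoder


@lru_cache(maxsize=8)
def cached_get_encoder(model_name: str, tokenizer_name: str | None = None):
    # String arguments are hashable, so lru_cache acts as a process-local
    # drop-in for the removed state.pretrained_tokenizers dict.
    return get_encoder(model_name, tokenizer_name)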