Refactor count_tokens, get_encoder methods to utils/helpers.py

Simplify get_encoder to not rely on global state. The tokenizer
caching it provided is not necessary for now.
This commit is contained in:
Debanjum
2025-11-13 20:04:47 -08:00
parent 15482c54b5
commit d57c597245
3 changed files with 63 additions and 71 deletions

View File

@@ -14,11 +14,9 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import PIL.Image import PIL.Image
import pyjson5 import pyjson5
import requests import requests
import tiktoken
import yaml import yaml
from langchain_core.messages.chat import ChatMessage from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError from pydantic import BaseModel, ConfigDict, ValidationError
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ( from khoj.database.models import (
@@ -32,9 +30,10 @@ from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
count_tokens,
get_encoder,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
merge_dicts, merge_dicts,
@@ -724,72 +723,6 @@ def generate_chatml_messages_with_context(
return messages return messages
def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    """
    Resolve a tokenizer for the given chat model.

    Prefers an explicitly configured HuggingFace tokenizer (cached in app
    state across calls); otherwise looks up the model's tiktoken encoding.
    Falls back to the gpt-4o encoding when neither can be loaded.
    """
    default_tokenizer = "gpt-4o"
    try:
        if not tokenizer_name:
            # as tiktoken doesn't recognize o1 model series yet
            lookup_name = "gpt-4o" if model_name.startswith("o1") else model_name
            return tiktoken.encoding_for_model(lookup_name)
        # Reuse a previously loaded HuggingFace tokenizer when available
        cached_encoder = state.pretrained_tokenizers.get(tokenizer_name)
        if cached_encoder is not None:
            return cached_encoder
        hf_encoder = AutoTokenizer.from_pretrained(tokenizer_name)
        state.pretrained_tokenizers[tokenizer_name] = hf_encoder
        return hf_encoder
    except Exception:
        if state.verbose > 2:
            logger.debug(
                f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
            )
        return tiktoken.encoding_for_model(default_tokenizer)
def count_tokens(
message_content: str | list[str | dict],
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
"""
Count the total number of tokens in a list of messages.
Assumes each images takes 500 tokens for approximation.
"""
if isinstance(message_content, list):
image_count = 0
message_content_parts: list[str] = []
# Collate message content into single string to ease token counting
for part in message_content:
if isinstance(part, dict) and part.get("type") == "image_url":
image_count += 1
elif isinstance(part, dict) and part.get("type") == "text":
message_content_parts.append(part["text"])
elif isinstance(part, dict) and hasattr(part, "model_dump"):
message_content_parts.append(json.dumps(part.model_dump()))
elif isinstance(part, dict) and hasattr(part, "__dict__"):
message_content_parts.append(json.dumps(part.__dict__))
elif isinstance(part, dict):
# If part is a dict but not a recognized type, convert to JSON string
try:
message_content_parts.append(json.dumps(part))
except (TypeError, ValueError) as e:
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
image_count += 1 # Treat as an image/binary if serialization fails
elif isinstance(part, str):
message_content_parts.append(part)
else:
logger.warning(f"Unknown message type: {part}. Skipping.")
message_content = "\n".join(message_content_parts).rstrip()
return len(encoder.encode(message_content)) + image_count * 500
elif isinstance(message_content, str):
return len(encoder.encode(message_content))
else:
return len(encoder.encode(json.dumps(message_content)))
def count_total_tokens( def count_total_tokens(
messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
) -> Tuple[int, int]: ) -> Tuple[int, int]:

View File

@@ -5,6 +5,7 @@ import copy
import datetime import datetime
import io import io
import ipaddress import ipaddress
import json
import logging import logging
import os import os
import platform import platform
@@ -31,6 +32,7 @@ import openai
import psutil import psutil
import pyjson5 import pyjson5
import requests import requests
import tiktoken
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
@@ -41,6 +43,7 @@ from magika import Magika
from PIL import Image from PIL import Image
from pydantic import BaseModel from pydantic import BaseModel
from pytz import country_names, country_timezones from pytz import country_names, country_timezones
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from khoj.utils import constants from khoj.utils import constants
@@ -976,6 +979,64 @@ def get_cost_of_chat_message(
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    """
    Load the tokenizer for the given chat model.

    Uses the explicitly configured HuggingFace tokenizer when provided,
    otherwise the model's tiktoken encoding. Falls back to the gpt-4o
    encoding when neither can be loaded.
    """
    default_tokenizer = "gpt-4o"
    try:
        if not tokenizer_name:
            # as tiktoken doesn't recognize o1 model series yet
            lookup_name = "gpt-4o" if model_name.startswith("o1") else model_name
            return tiktoken.encoding_for_model(lookup_name)
        return AutoTokenizer.from_pretrained(tokenizer_name)
    except Exception:
        return tiktoken.encoding_for_model(default_tokenizer)
def count_tokens(
message_content: str | list[str | dict],
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
"""
Count the total number of tokens in a list of messages.
Assumes each images takes 500 tokens for approximation.
"""
if isinstance(message_content, list):
image_count = 0
message_content_parts: list[str] = []
# Collate message content into single string to ease token counting
for part in message_content:
if isinstance(part, dict) and part.get("type") == "image_url":
image_count += 1
elif isinstance(part, dict) and part.get("type") == "text":
message_content_parts.append(part["text"])
elif isinstance(part, dict) and hasattr(part, "model_dump"):
message_content_parts.append(json.dumps(part.model_dump()))
elif isinstance(part, dict) and hasattr(part, "__dict__"):
message_content_parts.append(json.dumps(part.__dict__))
elif isinstance(part, dict):
# If part is a dict but not a recognized type, convert to JSON string
try:
message_content_parts.append(json.dumps(part))
except (TypeError, ValueError) as e:
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
image_count += 1 # Treat as an image/binary if serialization fails
elif isinstance(part, str):
message_content_parts.append(part)
else:
logger.warning(f"Unknown message type: {part}. Skipping.")
message_content = "\n".join(message_content_parts).rstrip()
return len(encoder.encode(message_content)) + image_count * 500
elif isinstance(message_content, str):
return len(encoder.encode(message_content))
else:
return len(encoder.encode(json.dumps(message_content)))
def get_chat_usage_metrics( def get_chat_usage_metrics(
model_name: str, model_name: str,
input_tokens: int = 0, input_tokens: int = 0,

View File

@@ -6,7 +6,6 @@ from typing import Dict, List
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI from openai import OpenAI
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from whisper import Whisper from whisper import Whisper
from khoj.database.models import ProcessLock from khoj.database.models import ProcessLock
@@ -35,7 +34,6 @@ telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
khoj_version: str = None khoj_version: str = None
device = get_device() device = get_device()
anonymous_mode: bool = False anonymous_mode: bool = False
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
billing_enabled: bool = ( billing_enabled: bool = (
os.getenv("STRIPE_API_KEY") is not None os.getenv("STRIPE_API_KEY") is not None
and os.getenv("STRIPE_SIGNING_SECRET") is not None and os.getenv("STRIPE_SIGNING_SECRET") is not None