mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Refactor count_tokens, get_encoder methods to utils/helpers.py
Simplify get_encoder to not rely on global state. The caching simplification is not necessary for now.
This commit is contained in:
@@ -14,11 +14,9 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
import PIL.Image
|
||||
import pyjson5
|
||||
import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import (
|
||||
@@ -32,9 +30,10 @@ from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
count_tokens,
|
||||
get_encoder,
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
merge_dicts,
|
||||
@@ -724,72 +723,6 @@ def generate_chatml_messages_with_context(
|
||||
return messages
|
||||
|
||||
|
||||
def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    """
    Resolve a token encoder for the given chat model.

    Prefers an explicitly configured HuggingFace tokenizer, reusing any
    instance cached in state.pretrained_tokenizers; otherwise looks up the
    model's tiktoken encoding. On any failure, falls back to the gpt-4o
    encoding (logging the fallback when verbosity > 2).
    """
    default_tokenizer = "gpt-4o"

    try:
        if tokenizer_name:
            # Reuse a previously loaded HF tokenizer when one is cached.
            cached = state.pretrained_tokenizers.get(tokenizer_name)
            if cached is not None:
                return cached
            loaded = AutoTokenizer.from_pretrained(tokenizer_name)
            state.pretrained_tokenizers[tokenizer_name] = loaded
            return loaded
        # as tiktoken doesn't recognize o1 model series yet
        lookup_model = "gpt-4o" if model_name.startswith("o1") else model_name
        return tiktoken.encoding_for_model(lookup_model)
    except Exception:
        fallback = tiktoken.encoding_for_model(default_tokenizer)
        if state.verbose > 2:
            logger.debug(
                f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
            )
        return fallback
|
||||
|
||||
|
||||
def count_tokens(
    message_content: str | list[str | dict],
    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
    """
    Count the total number of tokens in a message's content.

    Handles a plain string, an OpenAI-style list of content parts
    (text / image_url dicts, strings, serializable objects), or any other
    JSON-serializable value.

    Assumes each image takes 500 tokens for approximation.
    """
    if isinstance(message_content, list):
        image_count = 0
        message_content_parts: list[str] = []
        # Collate message content into single string to ease token counting
        for part in message_content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                image_count += 1
            elif isinstance(part, dict) and part.get("type") == "text":
                message_content_parts.append(part["text"])
            elif isinstance(part, dict):
                # If part is a dict but not a recognized type, convert to JSON string
                try:
                    message_content_parts.append(json.dumps(part))
                except (TypeError, ValueError) as e:
                    logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
                    image_count += 1  # Treat as an image/binary if serialization fails
            elif isinstance(part, str):
                message_content_parts.append(part)
            elif hasattr(part, "model_dump"):
                # Pydantic-style objects: serialize via their model_dump() contract.
                # Fix: this branch (and the __dict__ one below) was previously guarded
                # by isinstance(part, dict), which made it unreachable — plain dicts
                # never have model_dump — so such parts were silently skipped.
                message_content_parts.append(json.dumps(part.model_dump()))
            elif hasattr(part, "__dict__"):
                # Generic objects: serialize their attribute dict.
                message_content_parts.append(json.dumps(part.__dict__))
            else:
                logger.warning(f"Unknown message type: {part}. Skipping.")
        message_content = "\n".join(message_content_parts).rstrip()
        return len(encoder.encode(message_content)) + image_count * 500
    elif isinstance(message_content, str):
        return len(encoder.encode(message_content))
    else:
        # Any other JSON-serializable value: count tokens of its JSON form.
        return len(encoder.encode(json.dumps(message_content)))
|
||||
|
||||
|
||||
def count_total_tokens(
|
||||
messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
|
||||
) -> Tuple[int, int]:
|
||||
|
||||
@@ -5,6 +5,7 @@ import copy
|
||||
import datetime
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@@ -31,6 +32,7 @@ import openai
|
||||
import psutil
|
||||
import pyjson5
|
||||
import requests
|
||||
import tiktoken
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||
@@ -41,6 +43,7 @@ from magika import Magika
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from pytz import country_names, country_timezones
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from khoj.utils import constants
|
||||
|
||||
@@ -976,6 +979,64 @@ def get_cost_of_chat_message(
|
||||
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
|
||||
|
||||
|
||||
def get_encoder(
    model_name: str,
    tokenizer_name=None,
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
    """
    Return a token encoder for the given chat model.

    Loads the HuggingFace tokenizer when tokenizer_name is given; otherwise
    resolves the model's tiktoken encoding. Any failure falls back to the
    gpt-4o encoding.
    """
    default_tokenizer = "gpt-4o"

    try:
        if tokenizer_name:
            return AutoTokenizer.from_pretrained(tokenizer_name)
        # as tiktoken doesn't recognize o1 model series yet
        lookup_model = "gpt-4o" if model_name.startswith("o1") else model_name
        return tiktoken.encoding_for_model(lookup_model)
    except Exception:
        return tiktoken.encoding_for_model(default_tokenizer)
|
||||
|
||||
|
||||
def count_tokens(
    message_content: str | list[str | dict],
    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
) -> int:
    """
    Count the total number of tokens in a message's content.

    Handles a plain string, an OpenAI-style list of content parts
    (text / image_url dicts, strings, serializable objects), or any other
    JSON-serializable value.

    Assumes each image takes 500 tokens for approximation.
    """
    if isinstance(message_content, list):
        image_count = 0
        message_content_parts: list[str] = []
        # Collate message content into single string to ease token counting
        for part in message_content:
            if isinstance(part, dict) and part.get("type") == "image_url":
                image_count += 1
            elif isinstance(part, dict) and part.get("type") == "text":
                message_content_parts.append(part["text"])
            elif isinstance(part, dict):
                # If part is a dict but not a recognized type, convert to JSON string
                try:
                    message_content_parts.append(json.dumps(part))
                except (TypeError, ValueError) as e:
                    logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
                    image_count += 1  # Treat as an image/binary if serialization fails
            elif isinstance(part, str):
                message_content_parts.append(part)
            elif hasattr(part, "model_dump"):
                # Pydantic-style objects: serialize via their model_dump() contract.
                # Fix: this branch (and the __dict__ one below) was previously guarded
                # by isinstance(part, dict), which made it unreachable — plain dicts
                # never have model_dump — so such parts were silently skipped.
                message_content_parts.append(json.dumps(part.model_dump()))
            elif hasattr(part, "__dict__"):
                # Generic objects: serialize their attribute dict.
                message_content_parts.append(json.dumps(part.__dict__))
            else:
                logger.warning(f"Unknown message type: {part}. Skipping.")
        message_content = "\n".join(message_content_parts).rstrip()
        return len(encoder.encode(message_content)) + image_count * 500
    elif isinstance(message_content, str):
        return len(encoder.encode(message_content))
    else:
        # Any other JSON-serializable value: count tokens of its JSON form.
        return len(encoder.encode(json.dumps(message_content)))
|
||||
|
||||
|
||||
def get_chat_usage_metrics(
|
||||
model_name: str,
|
||||
input_tokens: int = 0,
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Dict, List
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from openai import OpenAI
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from whisper import Whisper
|
||||
|
||||
from khoj.database.models import ProcessLock
|
||||
@@ -35,7 +34,6 @@ telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
|
||||
khoj_version: str = None
|
||||
device = get_device()
|
||||
anonymous_mode: bool = False
|
||||
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
|
||||
billing_enabled: bool = (
|
||||
os.getenv("STRIPE_API_KEY") is not None
|
||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||
|
||||
Reference in New Issue
Block a user