mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Refactor count_tokens, get_encoder methods to utils/helper.py
Simplify get_encoder to not rely on global state. The caching simplification is not necessary for now.
This commit is contained in:
@@ -14,11 +14,9 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import pyjson5
|
import pyjson5
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core.messages.chat import ChatMessage
|
from langchain_core.messages.chat import ChatMessage
|
||||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
@@ -32,9 +30,10 @@ from khoj.search_filter.base_filter import BaseFilter
|
|||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.utils import state
|
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
count_tokens,
|
||||||
|
get_encoder,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
merge_dicts,
|
merge_dicts,
|
||||||
@@ -724,72 +723,6 @@ def generate_chatml_messages_with_context(
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def get_encoder(
|
|
||||||
model_name: str,
|
|
||||||
tokenizer_name=None,
|
|
||||||
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
|
|
||||||
default_tokenizer = "gpt-4o"
|
|
||||||
|
|
||||||
try:
|
|
||||||
if tokenizer_name:
|
|
||||||
if tokenizer_name in state.pretrained_tokenizers:
|
|
||||||
encoder = state.pretrained_tokenizers[tokenizer_name]
|
|
||||||
else:
|
|
||||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
|
||||||
state.pretrained_tokenizers[tokenizer_name] = encoder
|
|
||||||
else:
|
|
||||||
# as tiktoken doesn't recognize o1 model series yet
|
|
||||||
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
|
||||||
except Exception:
|
|
||||||
encoder = tiktoken.encoding_for_model(default_tokenizer)
|
|
||||||
if state.verbose > 2:
|
|
||||||
logger.debug(
|
|
||||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def count_tokens(
|
|
||||||
message_content: str | list[str | dict],
|
|
||||||
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Count the total number of tokens in a list of messages.
|
|
||||||
|
|
||||||
Assumes each images takes 500 tokens for approximation.
|
|
||||||
"""
|
|
||||||
if isinstance(message_content, list):
|
|
||||||
image_count = 0
|
|
||||||
message_content_parts: list[str] = []
|
|
||||||
# Collate message content into single string to ease token counting
|
|
||||||
for part in message_content:
|
|
||||||
if isinstance(part, dict) and part.get("type") == "image_url":
|
|
||||||
image_count += 1
|
|
||||||
elif isinstance(part, dict) and part.get("type") == "text":
|
|
||||||
message_content_parts.append(part["text"])
|
|
||||||
elif isinstance(part, dict) and hasattr(part, "model_dump"):
|
|
||||||
message_content_parts.append(json.dumps(part.model_dump()))
|
|
||||||
elif isinstance(part, dict) and hasattr(part, "__dict__"):
|
|
||||||
message_content_parts.append(json.dumps(part.__dict__))
|
|
||||||
elif isinstance(part, dict):
|
|
||||||
# If part is a dict but not a recognized type, convert to JSON string
|
|
||||||
try:
|
|
||||||
message_content_parts.append(json.dumps(part))
|
|
||||||
except (TypeError, ValueError) as e:
|
|
||||||
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
|
|
||||||
image_count += 1 # Treat as an image/binary if serialization fails
|
|
||||||
elif isinstance(part, str):
|
|
||||||
message_content_parts.append(part)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unknown message type: {part}. Skipping.")
|
|
||||||
message_content = "\n".join(message_content_parts).rstrip()
|
|
||||||
return len(encoder.encode(message_content)) + image_count * 500
|
|
||||||
elif isinstance(message_content, str):
|
|
||||||
return len(encoder.encode(message_content))
|
|
||||||
else:
|
|
||||||
return len(encoder.encode(json.dumps(message_content)))
|
|
||||||
|
|
||||||
|
|
||||||
def count_total_tokens(
|
def count_total_tokens(
|
||||||
messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
|
messages: list[ChatMessage], encoder, system_message: Optional[list[ChatMessage]] = None
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import copy
|
|||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@@ -31,6 +32,7 @@ import openai
|
|||||||
import psutil
|
import psutil
|
||||||
import pyjson5
|
import pyjson5
|
||||||
import requests
|
import requests
|
||||||
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||||
@@ -41,6 +43,7 @@ from magika import Magika
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pytz import country_names, country_timezones
|
from pytz import country_names, country_timezones
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
from khoj.utils import constants
|
from khoj.utils import constants
|
||||||
|
|
||||||
@@ -976,6 +979,64 @@ def get_cost_of_chat_message(
|
|||||||
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
|
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
|
||||||
|
|
||||||
|
|
||||||
|
def get_encoder(
|
||||||
|
model_name: str,
|
||||||
|
tokenizer_name=None,
|
||||||
|
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
|
||||||
|
default_tokenizer = "gpt-4o"
|
||||||
|
|
||||||
|
try:
|
||||||
|
if tokenizer_name:
|
||||||
|
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
|
else:
|
||||||
|
# as tiktoken doesn't recognize o1 model series yet
|
||||||
|
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
||||||
|
except Exception:
|
||||||
|
encoder = tiktoken.encoding_for_model(default_tokenizer)
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(
|
||||||
|
message_content: str | list[str | dict],
|
||||||
|
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of tokens in a list of messages.
|
||||||
|
|
||||||
|
Assumes each images takes 500 tokens for approximation.
|
||||||
|
"""
|
||||||
|
if isinstance(message_content, list):
|
||||||
|
image_count = 0
|
||||||
|
message_content_parts: list[str] = []
|
||||||
|
# Collate message content into single string to ease token counting
|
||||||
|
for part in message_content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "image_url":
|
||||||
|
image_count += 1
|
||||||
|
elif isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
message_content_parts.append(part["text"])
|
||||||
|
elif isinstance(part, dict) and hasattr(part, "model_dump"):
|
||||||
|
message_content_parts.append(json.dumps(part.model_dump()))
|
||||||
|
elif isinstance(part, dict) and hasattr(part, "__dict__"):
|
||||||
|
message_content_parts.append(json.dumps(part.__dict__))
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
# If part is a dict but not a recognized type, convert to JSON string
|
||||||
|
try:
|
||||||
|
message_content_parts.append(json.dumps(part))
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
logger.warning(f"Failed to serialize part {part} to JSON: {e}. Skipping.")
|
||||||
|
image_count += 1 # Treat as an image/binary if serialization fails
|
||||||
|
elif isinstance(part, str):
|
||||||
|
message_content_parts.append(part)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown message type: {part}. Skipping.")
|
||||||
|
message_content = "\n".join(message_content_parts).rstrip()
|
||||||
|
return len(encoder.encode(message_content)) + image_count * 500
|
||||||
|
elif isinstance(message_content, str):
|
||||||
|
return len(encoder.encode(message_content))
|
||||||
|
else:
|
||||||
|
return len(encoder.encode(json.dumps(message_content)))
|
||||||
|
|
||||||
|
|
||||||
def get_chat_usage_metrics(
|
def get_chat_usage_metrics(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
input_tokens: int = 0,
|
input_tokens: int = 0,
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Dict, List
|
|||||||
|
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
|
|
||||||
from khoj.database.models import ProcessLock
|
from khoj.database.models import ProcessLock
|
||||||
@@ -35,7 +34,6 @@ telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
|
|||||||
khoj_version: str = None
|
khoj_version: str = None
|
||||||
device = get_device()
|
device = get_device()
|
||||||
anonymous_mode: bool = False
|
anonymous_mode: bool = False
|
||||||
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
|
|
||||||
billing_enabled: bool = (
|
billing_enabled: bool = (
|
||||||
os.getenv("STRIPE_API_KEY") is not None
|
os.getenv("STRIPE_API_KEY") is not None
|
||||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user