diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 986724be..c52a4769 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -104,7 +104,11 @@ def anthropic_completion_with_backoff( # Calculate cost of chat input_tokens = final_message.usage.input_tokens output_tokens = final_message.usage.output_tokens - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + cache_read_tokens = final_message.usage.cache_read_input_tokens + cache_write_tokens = final_message.usage.cache_creation_input_tokens + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage") + ) # Save conversation trace tracer["chat_model"] = model_name @@ -207,7 +211,11 @@ def anthropic_llm_thread( # Calculate cost of chat input_tokens = final_message.usage.input_tokens output_tokens = final_message.usage.output_tokens - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + cache_read_tokens = final_message.usage.cache_read_input_tokens + cache_write_tokens = final_message.usage.cache_creation_input_tokens + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage") + ) # Save conversation trace tracer["chat_model"] = model_name diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index ff141c75..d63010b2 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -109,7 +109,7 @@ def gemini_completion_with_backoff( # Aggregate cost of chat input_tokens = response.usage_metadata.prompt_token_count if response else 0 output_tokens = response.usage_metadata.candidates_token_count if response else 0 - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage")) # Save conversation trace tracer["chat_model"] = model_name @@ -191,7 +191,7 @@ def gemini_llm_thread( # Calculate cost of chat input_tokens = chunk.usage_metadata.prompt_token_count output_tokens = chunk.usage_metadata.candidates_token_count - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage")) # Save conversation trace tracer["chat_model"] = model_name diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 3037270e..7d1f114d 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -93,7 +93,9 @@ def completion_with_backoff( chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 ) # Estimated costs returned by DeepInfra API - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost) + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) # Save conversation trace tracer["chat_model"] = model_name @@ -226,7 +228,9 @@ def llm_thread( cost = ( chunk.usage.model_extra.get("estimated_cost", 0) if hasattr(chunk, "usage") and chunk.usage else 0 ) # Estimated costs returned by DeepInfra API - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost) + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) # Save conversation trace tracer["chat_model"] = model_name diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 4f5c1cc8..d9a2785c 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -47,12 +47,12 @@ model_to_cost: Dict[str, Dict[str, float]] = { "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, - # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ - "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, - "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0}, - "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, - "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0}, - "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0}, - "claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0}, - "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0}, + # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api + "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0}, + "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0}, + "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75}, + "claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75}, + "claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75}, + "claude-3-7-sonnet@20250219": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75}, + "claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0, "cache_read": 0.3, "cache_write": 3.75}, } diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index a3042daf..c259dc20 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -596,7 +596,14 @@ def get_country_name_from_timezone(tz: str) -> str: return country_names.get(get_country_code_from_timezone(tz), "United States") -def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0): +def get_cost_of_chat_message( + model_name: str, + input_tokens: int = 0, + output_tokens: int = 0, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, + prev_cost: float = 0.0, +): """ Calculate cost of chat message based on input and output tokens """ @@ -604,21 +611,40 @@ def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_toke # Calculate cost of input and output tokens. Costs are per million tokens input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6) output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6) + cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6) + cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6) - return input_cost + output_cost + prev_cost + return input_cost + output_cost + cache_read_cost + cache_write_cost + prev_cost def get_chat_usage_metrics( - model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}, cost: float = None + model_name: str, + input_tokens: int = 0, + output_tokens: int = 0, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, + usage: dict = {}, + cost: float = None, ): """ Get usage metrics for chat message based on input and output tokens and cost """ - prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0} + prev_usage = usage or { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_tokens": 0, + "cache_write_tokens": 0, + "cost": 0.0, + } return { "input_tokens": prev_usage["input_tokens"] + input_tokens, "output_tokens": prev_usage["output_tokens"] + output_tokens, - "cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]), + "cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens, + "cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens, + "cost": cost + or get_cost_of_chat_message( + model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, prev_cost=prev_usage["cost"] + ), }