diff --git a/pyproject.toml b/pyproject.toml
index e59c29fc..b8e6c6ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ dependencies = [
     "django_apscheduler == 0.7.0",
     "anthropic == 0.49.0",
     "docx2txt == 0.8",
-    "google-genai == 1.5.0",
+    "google-genai == 1.11.0",
     "google-auth ~= 2.23.3",
     "pyjson5 == 1.6.7",
     "resend == 1.0.1",
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index 456943aa..6c2ffb8a 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -108,7 +108,7 @@ def anthropic_completion_with_backoff(
     cache_read_tokens = final_message.usage.cache_read_input_tokens
     cache_write_tokens = final_message.usage.cache_creation_input_tokens
     tracer["usage"] = get_chat_usage_metrics(
-        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
+        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
     )

     # Save conversation trace
@@ -213,7 +213,7 @@ def anthropic_llm_thread(
     cache_read_tokens = final_message.usage.cache_read_input_tokens
     cache_write_tokens = final_message.usage.cache_creation_input_tokens
     tracer["usage"] = get_chat_usage_metrics(
-        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
+        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
     )

     # Save conversation trace
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index ba0b93f7..73167ca2 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -133,6 +133,7 @@ def gemini_send_message_to_model(
     response_type="text",
     response_schema=None,
     model_kwargs=None,
+    deepthought=False,
     tracer={},
 ):
     """
@@ -154,6 +155,7 @@ def gemini_send_message_to_model(
         api_key=api_key,
         api_base_url=api_base_url,
         model_kwargs=model_kwargs,
+        deepthought=deepthought,
         tracer=tracer,
     )

@@ -181,6 +183,7 @@ def converse_gemini(
     generated_files: List[FileAttachment] = None,
     generated_asset_results: Dict[str, Dict] = {},
     program_execution_context: List[str] = None,
+    deepthought: Optional[bool] = False,
     tracer={},
 ):
     """
@@ -260,5 +263,6 @@ def converse_gemini(
         api_base_url=api_base_url,
         system_prompt=system_prompt,
         completion_func=completion_func,
+        deepthought=deepthought,
         tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 98b87e8a..9a8b4132 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
 gemini_clients: Dict[str, genai.Client] = {}

 MAX_OUTPUT_TOKENS_GEMINI = 8192
+MAX_REASONING_TOKENS_GEMINI = 10000
+
 SAFETY_SETTINGS = [
     gtypes.SafetySetting(
         category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
@@ -78,7 +80,15 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=1.0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
+    messages,
+    system_prompt,
+    model_name: str,
+    temperature=1.0,
+    api_key=None,
+    api_base_url: str = None,
+    model_kwargs=None,
+    deepthought=False,
+    tracer={},
 ) -> str:
     client = gemini_clients.get(api_key)
     if not client:
@@ -92,10 +102,15 @@ def gemini_completion_with_backoff(
     if model_kwargs and model_kwargs.get("response_schema"):
         response_schema = clean_response_schema(model_kwargs["response_schema"])

+    thinking_config = None
+    if deepthought and model_name.startswith("gemini-2.5"):
+        thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
         temperature=temperature,
+        thinking_config=thinking_config,
         max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
         safety_settings=SAFETY_SETTINGS,
         response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
@@ -119,7 +134,10 @@ def gemini_completion_with_backoff(
     # Aggregate cost of chat
     input_tokens = response.usage_metadata.prompt_token_count if response else 0
     output_tokens = response.usage_metadata.candidates_token_count if response else 0
-    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
+    thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
+    tracer["usage"] = get_chat_usage_metrics(
+        model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
+    )

     # Save conversation trace
     tracer["chat_model"] = model_name
@@ -147,12 +165,24 @@ def gemini_chat_completion_with_backoff(
     system_prompt,
     completion_func=None,
     model_kwargs=None,
+    deepthought=False,
     tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
+        args=(
+            g,
+            messages,
+            system_prompt,
+            model_name,
+            temperature,
+            api_key,
+            api_base_url,
+            model_kwargs,
+            deepthought,
+            tracer,
+        ),
     )
     t.start()
     return g
@@ -167,6 +197,7 @@ def gemini_llm_thread(
     api_key,
     api_base_url=None,
     model_kwargs=None,
+    deepthought=False,
     tracer: dict = {},
 ):
     try:
@@ -177,10 +208,15 @@ def gemini_llm_thread(
         formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

+        thinking_config = None
+        if deepthought and model_name.startswith("gemini-2.5"):
+            thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
+
         seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
         config = gtypes.GenerateContentConfig(
             system_instruction=system_prompt,
             temperature=temperature,
+            thinking_config=thinking_config,
             max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
             stop_sequences=["Notes:\n["],
             safety_settings=SAFETY_SETTINGS,
@@ -202,7 +238,10 @@ def gemini_llm_thread(
         # Calculate cost of chat
         input_tokens = chunk.usage_metadata.prompt_token_count
         output_tokens = chunk.usage_metadata.candidates_token_count
-        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
+        thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+        tracer["usage"] = get_chat_usage_metrics(
+            model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
+        )

         # Save conversation trace
         tracer["chat_model"] = model_name
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 88747e46..5779fab6 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1290,6 +1290,7 @@ async def send_message_to_model_wrapper(
             model=chat_model_name,
             response_type=response_type,
             response_schema=response_schema,
+            deepthought=deepthought,
             api_base_url=api_base_url,
             tracer=tracer,
         )
@@ -1593,6 +1594,7 @@ def generate_chat_response(
             generated_files=raw_generated_files,
             generated_asset_results=generated_asset_results,
             program_execution_context=program_execution_context,
+            deepthought=deepthought,
             tracer=tracer,
         )
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index 87767162..daf8469f 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -49,6 +49,8 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
     "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
     "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
+    "gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50},
+    "gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0},
     # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api
     "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
     "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index c990b70b..f0aa0cd6 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -601,6 +601,7 @@ def get_cost_of_chat_message(
     model_name: str,
     input_tokens: int = 0,
     output_tokens: int = 0,
+    thought_tokens: int = 0,
     cache_read_tokens: int = 0,
     cache_write_tokens: int = 0,
     prev_cost: float = 0.0,
@@ -612,10 +613,11 @@ def get_cost_of_chat_message(
     # Calculate cost of input and output tokens. Costs are per million tokens
     input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
     output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
+    thought_cost = constants.model_to_cost.get(model_name, {}).get("thought", 0) * (thought_tokens / 1e6)
     cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6)
     cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6)

-    return input_cost + output_cost + cache_read_cost + cache_write_cost + prev_cost
+    return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost


 def get_chat_usage_metrics(
@@ -624,6 +626,7 @@
     output_tokens: int = 0,
     cache_read_tokens: int = 0,
     cache_write_tokens: int = 0,
+    thought_tokens: int = 0,
     usage: dict = {},
     cost: float = None,
 ):
@@ -633,6 +636,7 @@
     prev_usage = usage or {
         "input_tokens": 0,
         "output_tokens": 0,
+        "thought_tokens": 0,
         "cache_read_tokens": 0,
         "cache_write_tokens": 0,
         "cost": 0.0,
@@ -640,11 +644,18 @@
     return {
         "input_tokens": prev_usage["input_tokens"] + input_tokens,
         "output_tokens": prev_usage["output_tokens"] + output_tokens,
+        "thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
         "cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens,
         "cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens,
         "cost": cost
         or get_cost_of_chat_message(
-            model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, prev_cost=prev_usage["cost"]
+            model_name,
+            input_tokens,
+            output_tokens,
+            thought_tokens,
+            cache_read_tokens,
+            cache_write_tokens,
+            prev_cost=prev_usage["cost"],
         ),
     }
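
For reviewers unfamiliar with the new SDK surface: roughly, the `deepthought` flag maps onto a `gtypes.ThinkingConfig` with a fixed thinking budget, which only Gemini 2.5 models accept. Below is a minimal standalone sketch of that path, assuming a `GEMINI_API_KEY` environment variable and the google-genai 1.11.0 pin this patch introduces; the prompt and budget value are illustrative, the config usage mirrors the hunks above.

```python
# Sketch of the deepthought -> ThinkingConfig mapping (google-genai >= 1.11.0,
# as pinned above). GEMINI_API_KEY and the prompt are assumptions.
import os

from google import genai
from google.genai import types as gtypes

MAX_REASONING_TOKENS_GEMINI = 10000  # same budget this patch uses

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

deepthought = True
model_name = "gemini-2.5-flash-preview-04-17"

# Only Gemini 2.5 models take a thinking budget; leave it unset otherwise.
thinking_config = None
if deepthought and model_name.startswith("gemini-2.5"):
    thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)

response = client.models.generate_content(
    model=model_name,
    contents="What is the 10th Fibonacci number?",
    config=gtypes.GenerateContentConfig(thinking_config=thinking_config),
)

print(response.text)
# thoughts_token_count is None when no thinking happened, hence the `or 0`
# guard used in the gemini utils above.
print(response.usage_metadata.thoughts_token_count or 0)
```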
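The Anthropic call sites switch `tracer.get("usage")` from positional to keyword for a reason: `get_chat_usage_metrics` gains `thought_tokens` ahead of `usage`, so the old sixth positional argument would silently bind the usage dict to `thought_tokens`. A minimal sketch with illustrative token counts:

```python
# After this patch, positionals bind as (model_name, input_tokens,
# output_tokens, cache_read_tokens, cache_write_tokens, thought_tokens,
# usage, cost); a sixth positional arg now lands in thought_tokens.
from khoj.utils.helpers import get_chat_usage_metrics

prev_usage = {"input_tokens": 10, "output_tokens": 5, "thought_tokens": 0,
              "cache_read_tokens": 0, "cache_write_tokens": 0, "cost": 0.0}

# Correct: usage= keeps the accumulator out of the thought_tokens slot.
usage = get_chat_usage_metrics("claude-3-5-haiku-20241022", 100, 50, usage=prev_usage)
assert usage["input_tokens"] == 110 and usage["output_tokens"] == 55
```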
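And a worked example of the new thought-token cost path, using the `gemini-2.5-flash-preview-04-17` rates added to `model_to_cost` above; the token counts are made up:

```python
# Thought tokens are billed at the separate per-million "thought" rate.
from khoj.utils.helpers import get_chat_usage_metrics

usage = get_chat_usage_metrics(
    model_name="gemini-2.5-flash-preview-04-17",
    input_tokens=1000,    # 1000 / 1e6 * $0.15 = $0.00015
    output_tokens=500,    #  500 / 1e6 * $0.60 = $0.00030
    thought_tokens=2000,  # 2000 / 1e6 * $3.50 = $0.00700
    usage=None,
)
assert abs(usage["cost"] - 0.00745) < 1e-9
```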