Support deepthought in research mode with new Gemini 2.5 reasoning model

The 2.5 flash model is the first hybrid reasoning models by Google

- Track costs of thoughts separately as they are priced differently
This commit is contained in:
Debanjum
2025-04-18 14:19:45 +05:30
parent f95173bb0a
commit eb1406bcb4
7 changed files with 67 additions and 9 deletions

View File

@@ -88,7 +88,7 @@ dependencies = [
"django_apscheduler == 0.7.0",
"anthropic == 0.49.0",
"docx2txt == 0.8",
"google-genai == 1.5.0",
"google-genai == 1.11.0",
"google-auth ~= 2.23.3",
"pyjson5 == 1.6.7",
"resend == 1.0.1",

View File

@@ -108,7 +108,7 @@ def anthropic_completion_with_backoff(
cache_read_tokens = final_message.usage.cache_read_input_tokens
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Save conversation trace
@@ -213,7 +213,7 @@ def anthropic_llm_thread(
cache_read_tokens = final_message.usage.cache_read_input_tokens
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Save conversation trace

View File

@@ -133,6 +133,7 @@ def gemini_send_message_to_model(
response_type="text",
response_schema=None,
model_kwargs=None,
deepthought=False,
tracer={},
):
"""
@@ -154,6 +155,7 @@ def gemini_send_message_to_model(
api_key=api_key,
api_base_url=api_base_url,
model_kwargs=model_kwargs,
deepthought=deepthought,
tracer=tracer,
)
@@ -181,6 +183,7 @@ def converse_gemini(
generated_files: List[FileAttachment] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
deepthought: Optional[bool] = False,
tracer={},
):
"""
@@ -260,5 +263,6 @@ def converse_gemini(
api_base_url=api_base_url,
system_prompt=system_prompt,
completion_func=completion_func,
deepthought=deepthought,
tracer=tracer,
)

View File

@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
gemini_clients: Dict[str, genai.Client] = {}
MAX_OUTPUT_TOKENS_GEMINI = 8192
MAX_REASONING_TOKENS_GEMINI = 10000
SAFETY_SETTINGS = [
gtypes.SafetySetting(
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
@@ -78,7 +80,15 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=1.0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
messages,
system_prompt,
model_name: str,
temperature=1.0,
api_key=None,
api_base_url: str = None,
model_kwargs=None,
deepthought=False,
tracer={},
) -> str:
client = gemini_clients.get(api_key)
if not client:
@@ -92,10 +102,15 @@ def gemini_completion_with_backoff(
if model_kwargs and model_kwargs.get("response_schema"):
response_schema = clean_response_schema(model_kwargs["response_schema"])
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
safety_settings=SAFETY_SETTINGS,
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
@@ -119,7 +134,10 @@ def gemini_completion_with_backoff(
# Aggregate cost of chat
input_tokens = response.usage_metadata.prompt_token_count if response else 0
output_tokens = response.usage_metadata.candidates_token_count if response else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Save conversation trace
tracer["chat_model"] = model_name
@@ -147,12 +165,24 @@ def gemini_chat_completion_with_backoff(
system_prompt,
completion_func=None,
model_kwargs=None,
deepthought=False,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
args=(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
api_base_url,
model_kwargs,
deepthought,
tracer,
),
)
t.start()
return g
@@ -167,6 +197,7 @@ def gemini_llm_thread(
api_key,
api_base_url=None,
model_kwargs=None,
deepthought=False,
tracer: dict = {},
):
try:
@@ -177,10 +208,15 @@ def gemini_llm_thread(
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
@@ -202,7 +238,10 @@ def gemini_llm_thread(
# Calculate cost of chat
input_tokens = chunk.usage_metadata.prompt_token_count
output_tokens = chunk.usage_metadata.candidates_token_count
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Save conversation trace
tracer["chat_model"] = model_name

View File

@@ -1290,6 +1290,7 @@ async def send_message_to_model_wrapper(
model=chat_model_name,
response_type=response_type,
response_schema=response_schema,
deepthought=deepthought,
api_base_url=api_base_url,
tracer=tracer,
)
@@ -1593,6 +1594,7 @@ def generate_chat_response(
generated_files=raw_generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
deepthought=deepthought,
tracer=tracer,
)

View File

@@ -49,6 +49,8 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
"gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50},
"gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},

View File

@@ -601,6 +601,7 @@ def get_cost_of_chat_message(
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0,
thought_tokens: int = 0,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
prev_cost: float = 0.0,
@@ -612,10 +613,11 @@ def get_cost_of_chat_message(
# Calculate cost of input and output tokens. Costs are per million tokens
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
thought_cost = constants.model_to_cost.get(model_name, {}).get("thought", 0) * (thought_tokens / 1e6)
cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6)
cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6)
return input_cost + output_cost + cache_read_cost + cache_write_cost + prev_cost
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
def get_chat_usage_metrics(
@@ -624,6 +626,7 @@ def get_chat_usage_metrics(
output_tokens: int = 0,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
thought_tokens: int = 0,
usage: dict = {},
cost: float = None,
):
@@ -633,6 +636,7 @@ def get_chat_usage_metrics(
prev_usage = usage or {
"input_tokens": 0,
"output_tokens": 0,
"thought_tokens": 0,
"cache_read_tokens": 0,
"cache_write_tokens": 0,
"cost": 0.0,
@@ -640,11 +644,18 @@ def get_chat_usage_metrics(
return {
"input_tokens": prev_usage["input_tokens"] + input_tokens,
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
"cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens,
"cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens,
"cost": cost
or get_cost_of_chat_message(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, prev_cost=prev_usage["cost"]
model_name,
input_tokens,
output_tokens,
thought_tokens,
cache_read_tokens,
cache_write_tokens,
prev_cost=prev_usage["cost"],
),
}