Support deepthought in research mode with new Gemini 2.5 reasoning model

- The 2.5 Flash model is the first hybrid reasoning model by Google
- Track the cost of thought tokens separately, as they are priced differently from standard output tokens
@@ -88,7 +88,7 @@ dependencies = [
     "django_apscheduler == 0.7.0",
     "anthropic == 0.49.0",
     "docx2txt == 0.8",
-    "google-genai == 1.5.0",
+    "google-genai == 1.11.0",
     "google-auth ~= 2.23.3",
     "pyjson5 == 1.6.7",
     "resend == 1.0.1",
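The `google-genai` bump is what enables the rest of this diff: the thinking controls used below (`gtypes.ThinkingConfig`) are not available in the older 1.5.0 release. A quick sanity-check sketch that the installed SDK exposes them (assumes the package is installed):

```python
# Sanity-check sketch: confirm the installed google-genai SDK exposes the
# thinking controls this commit relies on (available in newer releases,
# hence the bump from 1.5.0 to 1.11.0).
from google.genai import types as gtypes

config = gtypes.ThinkingConfig(thinking_budget=10000)
print(config.thinking_budget)  # 10000
```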
@@ -108,7 +108,7 @@ def anthropic_completion_with_backoff(
     cache_read_tokens = final_message.usage.cache_read_input_tokens
     cache_write_tokens = final_message.usage.cache_creation_input_tokens
     tracer["usage"] = get_chat_usage_metrics(
-        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
+        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
     )

     # Save conversation trace
@@ -213,7 +213,7 @@ def anthropic_llm_thread(
     cache_read_tokens = final_message.usage.cache_read_input_tokens
     cache_write_tokens = final_message.usage.cache_creation_input_tokens
     tracer["usage"] = get_chat_usage_metrics(
-        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
+        model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
     )

     # Save conversation trace
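Both Anthropic hunks make the same small change: `tracer.get("usage")` is now passed as the keyword `usage=`. This matters because `get_chat_usage_metrics` (last hunks below) gains a `thought_tokens` parameter ahead of `usage`, so a sixth positional argument would silently bind to the wrong parameter. A toy illustration with hypothetical signatures:

```python
# Toy illustration (hypothetical signatures) of the positional-binding
# hazard the keyword change avoids.
def metrics_old(model, inp, out, cache_r, cache_w, usage=None):
    return usage

def metrics_new(model, inp, out, cache_r, cache_w, thought_tokens=0, usage=None):
    # A sixth positional argument now lands in thought_tokens, not usage.
    return {"thought_tokens": thought_tokens, "usage": usage}

prev = {"cost": 0.25}
print(metrics_new("m", 1, 2, 3, 4, prev))        # prev misbound to thought_tokens
print(metrics_new("m", 1, 2, 3, 4, usage=prev))  # keyword stays correct
```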
@@ -133,6 +133,7 @@ def gemini_send_message_to_model(
     response_type="text",
     response_schema=None,
     model_kwargs=None,
+    deepthought=False,
     tracer={},
 ):
     """
@@ -154,6 +155,7 @@ def gemini_send_message_to_model(
         api_key=api_key,
         api_base_url=api_base_url,
         model_kwargs=model_kwargs,
+        deepthought=deepthought,
         tracer=tracer,
     )
@@ -181,6 +183,7 @@ def converse_gemini(
     generated_files: List[FileAttachment] = None,
     generated_asset_results: Dict[str, Dict] = {},
     program_execution_context: List[str] = None,
+    deepthought: Optional[bool] = False,
     tracer={},
 ):
     """
@@ -260,5 +263,6 @@ def converse_gemini(
         api_base_url=api_base_url,
         system_prompt=system_prompt,
         completion_func=completion_func,
+        deepthought=deepthought,
         tracer=tracer,
     )
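These four hunks only thread the new flag through the public entry points (`gemini_send_message_to_model`, `converse_gemini`); nothing interprets `deepthought` until `gemini_completion_with_backoff` below. A stub sketch of that plumbing style (hypothetical, simplified signatures):

```python
# Hypothetical, simplified sketch of the plumbing pattern: intermediate
# layers forward deepthought untouched; only the bottom layer acts on it.
def completion_with_backoff(prompt: str, deepthought: bool = False) -> str:
    budget = 10000 if deepthought else 0  # mirrors MAX_REASONING_TOKENS_GEMINI
    return f"[thinking_budget={budget}] response to: {prompt}"

def send_message_to_model(prompt: str, deepthought: bool = False) -> str:
    return completion_with_backoff(prompt, deepthought=deepthought)

print(send_message_to_model("compare these two papers", deepthought=True))
```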
@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
 gemini_clients: Dict[str, genai.Client] = {}

 MAX_OUTPUT_TOKENS_GEMINI = 8192
+MAX_REASONING_TOKENS_GEMINI = 10000
+
 SAFETY_SETTINGS = [
     gtypes.SafetySetting(
         category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
@@ -78,7 +80,15 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=1.0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
+    messages,
+    system_prompt,
+    model_name: str,
+    temperature=1.0,
+    api_key=None,
+    api_base_url: str = None,
+    model_kwargs=None,
+    deepthought=False,
+    tracer={},
 ) -> str:
     client = gemini_clients.get(api_key)
     if not client:
@@ -92,10 +102,15 @@ def gemini_completion_with_backoff(
     if model_kwargs and model_kwargs.get("response_schema"):
         response_schema = clean_response_schema(model_kwargs["response_schema"])

+    thinking_config = None
+    if deepthought and model_name.startswith("gemini-2.5"):
+        thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
+
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     config = gtypes.GenerateContentConfig(
         system_instruction=system_prompt,
         temperature=temperature,
+        thinking_config=thinking_config,
         max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
         safety_settings=SAFETY_SETTINGS,
         response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
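Outside Khoj, the same configuration can be exercised directly against the SDK. A minimal standalone sketch (the `GEMINI_API_KEY` environment variable is an assumption for this example; model name and budget mirror the diff):

```python
# Minimal standalone sketch of the config built above. Assumes the
# GEMINI_API_KEY env var is set; model name and budget mirror the diff.
import os

from google import genai
from google.genai import types as gtypes

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
response = client.models.generate_content(
    model="gemini-2.5-flash-preview-04-17",
    contents="Sketch a three-step research plan for comparing two papers.",
    config=gtypes.GenerateContentConfig(
        thinking_config=gtypes.ThinkingConfig(thinking_budget=10000),
        max_output_tokens=8192,
    ),
)
print(response.text)
```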
@@ -119,7 +134,10 @@ def gemini_completion_with_backoff(
     # Aggregate cost of chat
     input_tokens = response.usage_metadata.prompt_token_count if response else 0
     output_tokens = response.usage_metadata.candidates_token_count if response else 0
-    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
+    thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
+    tracer["usage"] = get_chat_usage_metrics(
+        model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
+    )

     # Save conversation trace
     tracer["chat_model"] = model_name
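One subtlety in the new accounting line: `thoughts_token_count` is `None` when no thinking happened, and the expression relies on `or` binding tighter than the conditional, i.e. it parses as `(... or 0) if response else 0`. A two-line check of that normalization:

```python
# thoughts_token_count is None when thinking is disabled; `or 0`
# normalizes it to an int before it reaches the cost accounting.
thoughts_token_count = None
assert (thoughts_token_count or 0) == 0
```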
@@ -147,12 +165,24 @@ def gemini_chat_completion_with_backoff(
     system_prompt,
     completion_func=None,
     model_kwargs=None,
+    deepthought=False,
     tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
+        args=(
+            g,
+            messages,
+            system_prompt,
+            model_name,
+            temperature,
+            api_key,
+            api_base_url,
+            model_kwargs,
+            deepthought,
+            tracer,
+        ),
     )
     t.start()
     return g
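For context on this hunk's shape: `gemini_chat_completion_with_backoff` streams by running the LLM call on a worker thread that feeds a `ThreadedGenerator`, so `deepthought` must be added to the `args` tuple in the same position it holds in `gemini_llm_thread`'s parameter list. A stripped-down, queue-based stand-in for that pattern (`ThreadedGenerator` itself is Khoj's own class):

```python
import queue
from threading import Thread

# Queue-based stand-in for Khoj's ThreadedGenerator: a worker thread
# pushes chunks; the caller iterates them as they arrive.
class SimpleThreadedGenerator:
    def __init__(self):
        self.q = queue.Queue()

    def send(self, chunk):
        self.q.put(chunk)

    def close(self):
        self.q.put(StopIteration)

    def __iter__(self):
        while (item := self.q.get()) is not StopIteration:
            yield item

def fake_llm_thread(g, prompt, deepthought):
    prefix = "(thinking) " if deepthought else ""
    for word in f"{prefix}echo: {prompt}".split():
        g.send(word + " ")
    g.close()

g = SimpleThreadedGenerator()
Thread(target=fake_llm_thread, args=(g, "hello", True)).start()
print("".join(g))
```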
@@ -167,6 +197,7 @@ def gemini_llm_thread(
     api_key,
     api_base_url=None,
     model_kwargs=None,
+    deepthought=False,
     tracer: dict = {},
 ):
     try:
@@ -177,10 +208,15 @@ def gemini_llm_thread(

         formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

+        thinking_config = None
+        if deepthought and model_name.startswith("gemini-2.5"):
+            thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
+
         seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
         config = gtypes.GenerateContentConfig(
             system_instruction=system_prompt,
             temperature=temperature,
+            thinking_config=thinking_config,
             max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
             stop_sequences=["Notes:\n["],
             safety_settings=SAFETY_SETTINGS,
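The streaming path builds the identical `ThinkingConfig`; against the raw SDK the analogue looks like this (a sketch under the same assumptions as the non-streaming example above):

```python
# Streaming analogue of the earlier sketch; same assumptions
# (GEMINI_API_KEY set, model/budget taken from the diff).
import os

from google import genai
from google.genai import types as gtypes

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
for chunk in client.models.generate_content_stream(
    model="gemini-2.5-flash-preview-04-17",
    contents="Outline a research plan in three bullets.",
    config=gtypes.GenerateContentConfig(
        thinking_config=gtypes.ThinkingConfig(thinking_budget=10000),
    ),
):
    print(chunk.text or "", end="")
```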
@@ -202,7 +238,10 @@ def gemini_llm_thread(
         # Calculate cost of chat
        input_tokens = chunk.usage_metadata.prompt_token_count
        output_tokens = chunk.usage_metadata.candidates_token_count
-       tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
+       thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
+       tracer["usage"] = get_chat_usage_metrics(
+           model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
+       )

        # Save conversation trace
        tracer["chat_model"] = model_name
@@ -1290,6 +1290,7 @@ async def send_message_to_model_wrapper(
         model=chat_model_name,
         response_type=response_type,
         response_schema=response_schema,
+        deepthought=deepthought,
         api_base_url=api_base_url,
         tracer=tracer,
     )
@@ -1593,6 +1594,7 @@ def generate_chat_response(
         generated_files=raw_generated_files,
         generated_asset_results=generated_asset_results,
         program_execution_context=program_execution_context,
+        deepthought=deepthought,
         tracer=tracer,
     )
@@ -49,6 +49,8 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
     "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
     "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
+    "gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50},
+    "gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0},
     # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api
     "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
     "claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
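The separate `thought` rate is the point of the whole accounting change: for `gemini-2.5-flash-preview-04-17`, thinking tokens cost $3.50 per million, almost 6x the $0.60 output rate, so folding them into output tokens would badly understate spend. A quick check with the rates above:

```python
# Back-of-envelope cost check with the gemini-2.5-flash rates above
# (prices are USD per million tokens).
rates = {"input": 0.15, "output": 0.60, "thought": 3.50}
input_tokens, output_tokens, thought_tokens = 10_000, 2_000, 8_000

cost = (
    rates["input"] * input_tokens / 1e6
    + rates["output"] * output_tokens / 1e6
    + rates["thought"] * thought_tokens / 1e6
)
print(f"${cost:.4f}")  # $0.0307: thought tokens dominate the bill
```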
@@ -601,6 +601,7 @@ def get_cost_of_chat_message(
     model_name: str,
     input_tokens: int = 0,
     output_tokens: int = 0,
+    thought_tokens: int = 0,
     cache_read_tokens: int = 0,
     cache_write_tokens: int = 0,
     prev_cost: float = 0.0,
@@ -612,10 +613,11 @@ def get_cost_of_chat_message(
     # Calculate cost of input and output tokens. Costs are per million tokens
     input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
     output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
+    thought_cost = constants.model_to_cost.get(model_name, {}).get("thought", 0) * (thought_tokens / 1e6)
     cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6)
     cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6)

-    return input_cost + output_cost + cache_read_cost + cache_write_cost + prev_cost
+    return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost


 def get_chat_usage_metrics(
@@ -624,6 +626,7 @@ def get_chat_usage_metrics(
     output_tokens: int = 0,
     cache_read_tokens: int = 0,
     cache_write_tokens: int = 0,
+    thought_tokens: int = 0,
     usage: dict = {},
     cost: float = None,
 ):
@@ -633,6 +636,7 @@ def get_chat_usage_metrics(
     prev_usage = usage or {
         "input_tokens": 0,
         "output_tokens": 0,
+        "thought_tokens": 0,
         "cache_read_tokens": 0,
         "cache_write_tokens": 0,
         "cost": 0.0,
@@ -640,11 +644,18 @@ def get_chat_usage_metrics(
     return {
         "input_tokens": prev_usage["input_tokens"] + input_tokens,
         "output_tokens": prev_usage["output_tokens"] + output_tokens,
+        "thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
         "cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens,
         "cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens,
         "cost": cost
         or get_cost_of_chat_message(
-            model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, prev_cost=prev_usage["cost"]
+            model_name,
+            input_tokens,
+            output_tokens,
+            thought_tokens,
+            cache_read_tokens,
+            cache_write_tokens,
+            prev_cost=prev_usage["cost"],
         ),
     }
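Note the `prev_usage.get("thought_tokens", 0)` fallback: usage dicts persisted before this change lack the key, so accumulation must tolerate its absence. A hypothetical stand-alone reduction of the pattern:

```python
# Hypothetical stand-alone reduction of the accumulation pattern above:
# .get() with a default keeps running totals compatible with usage dicts
# created before thought_tokens existed.
def accumulate_usage(prev, input_tokens=0, output_tokens=0, thought_tokens=0):
    prev = prev or {}
    return {
        "input_tokens": prev.get("input_tokens", 0) + input_tokens,
        "output_tokens": prev.get("output_tokens", 0) + output_tokens,
        "thought_tokens": prev.get("thought_tokens", 0) + thought_tokens,
    }

usage = accumulate_usage(None, input_tokens=100, output_tokens=20)
usage = accumulate_usage(usage, input_tokens=50, output_tokens=10, thought_tokens=800)
print(usage)  # {'input_tokens': 150, 'output_tokens': 30, 'thought_tokens': 800}
```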