mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support deepthought in research mode with new Gemini 2.5 reasoning model
The 2.5 flash model is the first hybrid reasoning models by Google - Track costs of thoughts separately as they are priced differently
This commit is contained in:
@@ -88,7 +88,7 @@ dependencies = [
|
|||||||
"django_apscheduler == 0.7.0",
|
"django_apscheduler == 0.7.0",
|
||||||
"anthropic == 0.49.0",
|
"anthropic == 0.49.0",
|
||||||
"docx2txt == 0.8",
|
"docx2txt == 0.8",
|
||||||
"google-genai == 1.5.0",
|
"google-genai == 1.11.0",
|
||||||
"google-auth ~= 2.23.3",
|
"google-auth ~= 2.23.3",
|
||||||
"pyjson5 == 1.6.7",
|
"pyjson5 == 1.6.7",
|
||||||
"resend == 1.0.1",
|
"resend == 1.0.1",
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ def anthropic_completion_with_backoff(
|
|||||||
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
||||||
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
@@ -213,7 +213,7 @@ def anthropic_llm_thread(
|
|||||||
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
cache_read_tokens = final_message.usage.cache_read_input_tokens
|
||||||
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
cache_write_tokens = final_message.usage.cache_creation_input_tokens
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, tracer.get("usage")
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ def gemini_send_message_to_model(
|
|||||||
response_type="text",
|
response_type="text",
|
||||||
response_schema=None,
|
response_schema=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
deepthought=False,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -154,6 +155,7 @@ def gemini_send_message_to_model(
|
|||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -181,6 +183,7 @@ def converse_gemini(
|
|||||||
generated_files: List[FileAttachment] = None,
|
generated_files: List[FileAttachment] = None,
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
program_execution_context: List[str] = None,
|
program_execution_context: List[str] = None,
|
||||||
|
deepthought: Optional[bool] = False,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -260,5 +263,6 @@ def converse_gemini(
|
|||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ logger = logging.getLogger(__name__)
|
|||||||
gemini_clients: Dict[str, genai.Client] = {}
|
gemini_clients: Dict[str, genai.Client] = {}
|
||||||
|
|
||||||
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
MAX_OUTPUT_TOKENS_GEMINI = 8192
|
||||||
|
MAX_REASONING_TOKENS_GEMINI = 10000
|
||||||
|
|
||||||
SAFETY_SETTINGS = [
|
SAFETY_SETTINGS = [
|
||||||
gtypes.SafetySetting(
|
gtypes.SafetySetting(
|
||||||
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
category=gtypes.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||||
@@ -78,7 +80,15 @@ def get_gemini_client(api_key, api_base_url=None) -> genai.Client:
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages, system_prompt, model_name, temperature=1.0, api_key=None, api_base_url=None, model_kwargs=None, tracer={}
|
messages,
|
||||||
|
system_prompt,
|
||||||
|
model_name: str,
|
||||||
|
temperature=1.0,
|
||||||
|
api_key=None,
|
||||||
|
api_base_url: str = None,
|
||||||
|
model_kwargs=None,
|
||||||
|
deepthought=False,
|
||||||
|
tracer={},
|
||||||
) -> str:
|
) -> str:
|
||||||
client = gemini_clients.get(api_key)
|
client = gemini_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
@@ -92,10 +102,15 @@ def gemini_completion_with_backoff(
|
|||||||
if model_kwargs and model_kwargs.get("response_schema"):
|
if model_kwargs and model_kwargs.get("response_schema"):
|
||||||
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
response_schema = clean_response_schema(model_kwargs["response_schema"])
|
||||||
|
|
||||||
|
thinking_config = None
|
||||||
|
if deepthought and model_name.startswith("gemini-2-5"):
|
||||||
|
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||||
|
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
thinking_config=thinking_config,
|
||||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
safety_settings=SAFETY_SETTINGS,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
|
||||||
@@ -119,7 +134,10 @@ def gemini_completion_with_backoff(
|
|||||||
# Aggregate cost of chat
|
# Aggregate cost of chat
|
||||||
input_tokens = response.usage_metadata.prompt_token_count if response else 0
|
input_tokens = response.usage_metadata.prompt_token_count if response else 0
|
||||||
output_tokens = response.usage_metadata.candidates_token_count if response else 0
|
output_tokens = response.usage_metadata.candidates_token_count if response else 0
|
||||||
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
|
thought_tokens = response.usage_metadata.thoughts_token_count or 0 if response else 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
|
)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
@@ -147,12 +165,24 @@ def gemini_chat_completion_with_backoff(
|
|||||||
system_prompt,
|
system_prompt,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
t = Thread(
|
t = Thread(
|
||||||
target=gemini_llm_thread,
|
target=gemini_llm_thread,
|
||||||
args=(g, messages, system_prompt, model_name, temperature, api_key, api_base_url, model_kwargs, tracer),
|
args=(
|
||||||
|
g,
|
||||||
|
messages,
|
||||||
|
system_prompt,
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
api_key,
|
||||||
|
api_base_url,
|
||||||
|
model_kwargs,
|
||||||
|
deepthought,
|
||||||
|
tracer,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
@@ -167,6 +197,7 @@ def gemini_llm_thread(
|
|||||||
api_key,
|
api_key,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@@ -177,10 +208,15 @@ def gemini_llm_thread(
|
|||||||
|
|
||||||
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
|
|
||||||
|
thinking_config = None
|
||||||
|
if deepthought and model_name.startswith("gemini-2-5"):
|
||||||
|
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
|
||||||
|
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
config = gtypes.GenerateContentConfig(
|
config = gtypes.GenerateContentConfig(
|
||||||
system_instruction=system_prompt,
|
system_instruction=system_prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
thinking_config=thinking_config,
|
||||||
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
|
||||||
stop_sequences=["Notes:\n["],
|
stop_sequences=["Notes:\n["],
|
||||||
safety_settings=SAFETY_SETTINGS,
|
safety_settings=SAFETY_SETTINGS,
|
||||||
@@ -202,7 +238,10 @@ def gemini_llm_thread(
|
|||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = chunk.usage_metadata.prompt_token_count
|
input_tokens = chunk.usage_metadata.prompt_token_count
|
||||||
output_tokens = chunk.usage_metadata.candidates_token_count
|
output_tokens = chunk.usage_metadata.candidates_token_count
|
||||||
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, usage=tracer.get("usage"))
|
thought_tokens = chunk.usage_metadata.thoughts_token_count or 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
|
)
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
|
|||||||
@@ -1290,6 +1290,7 @@ async def send_message_to_model_wrapper(
|
|||||||
model=chat_model_name,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
response_schema=response_schema,
|
response_schema=response_schema,
|
||||||
|
deepthought=deepthought,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
@@ -1593,6 +1594,7 @@ def generate_chat_response(
|
|||||||
generated_files=raw_generated_files,
|
generated_files=raw_generated_files,
|
||||||
generated_asset_results=generated_asset_results,
|
generated_asset_results=generated_asset_results,
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
|
deepthought=deepthought,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ model_to_cost: Dict[str, Dict[str, float]] = {
|
|||||||
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
||||||
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
|
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
|
||||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||||
|
"gemini-2.5-flash-preview-04-17": {"input": 0.15, "output": 0.60, "thought": 3.50},
|
||||||
|
"gemini-2.5-pro-preview-03-25": {"input": 1.25, "output": 10.0},
|
||||||
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api
|
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api
|
||||||
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
|
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
|
||||||
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
|
"claude-3-5-haiku@20241022": {"input": 1.0, "output": 5.0, "cache_read": 0.08, "cache_write": 1.0},
|
||||||
|
|||||||
@@ -601,6 +601,7 @@ def get_cost_of_chat_message(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
input_tokens: int = 0,
|
input_tokens: int = 0,
|
||||||
output_tokens: int = 0,
|
output_tokens: int = 0,
|
||||||
|
thought_tokens: int = 0,
|
||||||
cache_read_tokens: int = 0,
|
cache_read_tokens: int = 0,
|
||||||
cache_write_tokens: int = 0,
|
cache_write_tokens: int = 0,
|
||||||
prev_cost: float = 0.0,
|
prev_cost: float = 0.0,
|
||||||
@@ -612,10 +613,11 @@ def get_cost_of_chat_message(
|
|||||||
# Calculate cost of input and output tokens. Costs are per million tokens
|
# Calculate cost of input and output tokens. Costs are per million tokens
|
||||||
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
|
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
|
||||||
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
|
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
|
||||||
|
thought_cost = constants.model_to_cost.get(model_name, {}).get("thought", 0) * (thought_tokens / 1e6)
|
||||||
cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6)
|
cache_read_cost = constants.model_to_cost.get(model_name, {}).get("cache_read", 0) * (cache_read_tokens / 1e6)
|
||||||
cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6)
|
cache_write_cost = constants.model_to_cost.get(model_name, {}).get("cache_write", 0) * (cache_write_tokens / 1e6)
|
||||||
|
|
||||||
return input_cost + output_cost + cache_read_cost + cache_write_cost + prev_cost
|
return input_cost + output_cost + thought_cost + cache_read_cost + cache_write_cost + prev_cost
|
||||||
|
|
||||||
|
|
||||||
def get_chat_usage_metrics(
|
def get_chat_usage_metrics(
|
||||||
@@ -624,6 +626,7 @@ def get_chat_usage_metrics(
|
|||||||
output_tokens: int = 0,
|
output_tokens: int = 0,
|
||||||
cache_read_tokens: int = 0,
|
cache_read_tokens: int = 0,
|
||||||
cache_write_tokens: int = 0,
|
cache_write_tokens: int = 0,
|
||||||
|
thought_tokens: int = 0,
|
||||||
usage: dict = {},
|
usage: dict = {},
|
||||||
cost: float = None,
|
cost: float = None,
|
||||||
):
|
):
|
||||||
@@ -633,6 +636,7 @@ def get_chat_usage_metrics(
|
|||||||
prev_usage = usage or {
|
prev_usage = usage or {
|
||||||
"input_tokens": 0,
|
"input_tokens": 0,
|
||||||
"output_tokens": 0,
|
"output_tokens": 0,
|
||||||
|
"thought_tokens": 0,
|
||||||
"cache_read_tokens": 0,
|
"cache_read_tokens": 0,
|
||||||
"cache_write_tokens": 0,
|
"cache_write_tokens": 0,
|
||||||
"cost": 0.0,
|
"cost": 0.0,
|
||||||
@@ -640,11 +644,18 @@ def get_chat_usage_metrics(
|
|||||||
return {
|
return {
|
||||||
"input_tokens": prev_usage["input_tokens"] + input_tokens,
|
"input_tokens": prev_usage["input_tokens"] + input_tokens,
|
||||||
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
||||||
|
"thought_tokens": prev_usage.get("thought_tokens", 0) + thought_tokens,
|
||||||
"cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens,
|
"cache_read_tokens": prev_usage.get("cache_read_tokens", 0) + cache_read_tokens,
|
||||||
"cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens,
|
"cache_write_tokens": prev_usage.get("cache_write_tokens", 0) + cache_write_tokens,
|
||||||
"cost": cost
|
"cost": cost
|
||||||
or get_cost_of_chat_message(
|
or get_cost_of_chat_message(
|
||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, prev_cost=prev_usage["cost"]
|
model_name,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
thought_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
cache_write_tokens,
|
||||||
|
prev_cost=prev_usage["cost"],
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user