Support Claude 3.7 and use its extended thinking in research mode

Claude 3.7 Sonnet is Anthropic's first reasoning model. It provides a
single model/API capable of both standard and extended thinking. Utilize
the extended thinking in Khoj's research mode.

Increase default max output tokens to 8K for Anthropic models.
This commit is contained in:
Debanjum
2025-03-11 01:54:12 +05:30
parent 69048a859f
commit 50f71be03d
7 changed files with 68 additions and 18 deletions

View File

@@ -86,7 +86,7 @@ dependencies = [
"pytz ~= 2024.1", "pytz ~= 2024.1",
"cron-descriptor == 1.4.3", "cron-descriptor == 1.4.3",
"django_apscheduler == 0.6.2", "django_apscheduler == 0.6.2",
"anthropic == 0.26.1", "anthropic == 0.49.0",
"docx2txt == 0.8", "docx2txt == 0.8",
"google-generativeai == 0.8.3", "google-generativeai == 0.8.3",
"pyjson5 == 1.6.7", "pyjson5 == 1.6.7",

View File

@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
def extract_questions_anthropic( def extract_questions_anthropic(
text, text,
model: Optional[str] = "claude-instant-1.2", model: Optional[str] = "claude-3-7-sonnet-latest",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
temperature=0.7, temperature=0.7,
@@ -122,7 +122,7 @@ def extract_questions_anthropic(
return questions return questions
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}): def anthropic_send_message_to_model(messages, api_key, model, response_type="text", deepthought=False, tracer={}):
""" """
Send message to model Send message to model
""" """
@@ -135,6 +135,7 @@ def anthropic_send_message_to_model(messages, api_key, model, response_type="tex
model_name=model, model_name=model,
api_key=api_key, api_key=api_key,
response_type=response_type, response_type=response_type,
deepthought=deepthought,
tracer=tracer, tracer=tracer,
) )
@@ -145,7 +146,7 @@ def converse_anthropic(
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None,
conversation_log={}, conversation_log={},
model: Optional[str] = "claude-3-5-sonnet-20241022", model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None, api_key: Optional[str] = None,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
@@ -160,6 +161,7 @@ def converse_anthropic(
generated_files: List[FileAttachment] = None, generated_files: List[FileAttachment] = None,
program_execution_context: Optional[List[str]] = None, program_execution_context: Optional[List[str]] = None,
generated_asset_results: Dict[str, Dict] = {}, generated_asset_results: Dict[str, Dict] = {},
deepthought: Optional[bool] = False,
tracer: dict = {}, tracer: dict = {},
): ):
""" """
@@ -239,5 +241,6 @@ def converse_anthropic(
system_prompt=system_prompt, system_prompt=system_prompt,
completion_func=completion_func, completion_func=completion_func,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
deepthought=deepthought,
tracer=tracer, tracer=tracer,
) )

View File

@@ -17,10 +17,8 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace, commit_conversation_trace,
get_image_from_url, get_image_from_url,
) )
from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
get_chat_usage_metrics, get_chat_usage_metrics,
in_debug_mode,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
) )
@@ -30,7 +28,8 @@ logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {} anthropic_clients: Dict[str, anthropic.Anthropic] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 3000 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
MAX_REASONING_TOKENS_ANTHROPIC = 12000
@retry( @retry(
@@ -42,12 +41,13 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
def anthropic_completion_with_backoff( def anthropic_completion_with_backoff(
messages, messages,
system_prompt, system_prompt,
model_name, model_name: str,
temperature=0, temperature=0,
api_key=None, api_key=None,
model_kwargs=None, model_kwargs=None,
max_tokens=None, max_tokens=None,
response_type="text", response_type="text",
deepthought=False,
tracer={}, tracer={},
) -> str: ) -> str:
if api_key not in anthropic_clients: if api_key not in anthropic_clients:
@@ -57,18 +57,24 @@ def anthropic_completion_with_backoff(
client = anthropic_clients[api_key] client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
if response_type == "json_object": aggregated_response = ""
# Prefill model response with '{' to make it output a valid JSON object if response_type == "json_object" and not deepthought:
# Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
formatted_messages += [{"role": "assistant", "content": "{"}] formatted_messages += [{"role": "assistant", "content": "{"}]
aggregated_response += "{"
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
final_message = None final_message = None
model_kwargs = model_kwargs or dict() model_kwargs = model_kwargs or dict()
if system_prompt: if system_prompt:
model_kwargs["system"] = system_prompt model_kwargs["system"] = system_prompt
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and model_name.startswith("claude-3-7"):
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
# Temperature control not supported when using extended thinking
temperature = 1.0
with client.messages.stream( with client.messages.stream(
messages=formatted_messages, messages=formatted_messages,
model=model_name, # type: ignore model=model_name, # type: ignore
@@ -111,20 +117,41 @@ def anthropic_chat_completion_with_backoff(
system_prompt, system_prompt,
max_prompt_size=None, max_prompt_size=None,
completion_func=None, completion_func=None,
deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread( t = Thread(
target=anthropic_llm_thread, target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer), args=(
g,
messages,
system_prompt,
model_name,
temperature,
api_key,
max_prompt_size,
deepthought,
model_kwargs,
tracer,
),
) )
t.start() t.start()
return g return g
def anthropic_llm_thread( def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={} g,
messages,
system_prompt,
model_name,
temperature,
api_key,
max_prompt_size=None,
deepthought=False,
model_kwargs=None,
tracer={},
): ):
try: try:
if api_key not in anthropic_clients: if api_key not in anthropic_clients:
@@ -133,6 +160,14 @@ def anthropic_llm_thread(
else: else:
client: anthropic.Anthropic = anthropic_clients[api_key] client: anthropic.Anthropic = anthropic_clients[api_key]
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and model_name.startswith("claude-3-7"):
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
# Temperature control not supported when using extended thinking
temperature = 1.0
formatted_messages: List[anthropic.types.MessageParam] = [ formatted_messages: List[anthropic.types.MessageParam] = [
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
] ]
@@ -145,8 +180,8 @@ def anthropic_llm_thread(
temperature=temperature, temperature=temperature,
system=system_prompt, system=system_prompt,
timeout=20, timeout=20,
max_tokens=DEFAULT_MAX_TOKENS_ANTHROPIC, max_tokens=max_tokens,
**(model_kwargs or dict()), **model_kwargs,
) as stream: ) as stream:
for text in stream.text_stream: for text in stream.text_stream:
aggregated_response += text aggregated_response += text

View File

@@ -61,6 +61,9 @@ model_to_prompt_size = {
"gemini-1.5-pro": 60000, "gemini-1.5-pro": 60000,
# Anthropic Models # Anthropic Models
"claude-3-5-sonnet-20241022": 60000, "claude-3-5-sonnet-20241022": 60000,
"claude-3-5-sonnet-latest": 60000,
"claude-3-7-sonnet-20250219": 60000,
"claude-3-7-sonnet-latest": 60000,
"claude-3-5-haiku-20241022": 60000, "claude-3-5-haiku-20241022": 60000,
# Offline Models # Offline Models
"bartowski/Qwen2.5-14B-Instruct-GGUF": 20000, "bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,

View File

@@ -1125,6 +1125,7 @@ async def send_message_to_model_wrapper(
query: str, query: str,
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
deepthought: bool = False,
user: KhojUser = None, user: KhojUser = None,
query_images: List[str] = None, query_images: List[str] = None,
context: str = "", context: str = "",
@@ -1227,6 +1228,7 @@ async def send_message_to_model_wrapper(
api_key=api_key, api_key=api_key,
model=chat_model_name, model=chat_model_name,
response_type=response_type, response_type=response_type,
deepthought=deepthought,
tracer=tracer, tracer=tracer,
) )
elif model_type == ChatModel.ModelType.GOOGLE: elif model_type == ChatModel.ModelType.GOOGLE:
@@ -1425,11 +1427,13 @@ def generate_chat_response(
) )
query_to_run = q query_to_run = q
deepthought = False
if meta_research: if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>" query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
compiled_references = [] compiled_references = []
online_results = {} online_results = {}
code_results = {} code_results = {}
deepthought = True
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed) chat_model = ConversationAdapters.get_valid_chat_model(user, conversation, is_subscribed)
vision_available = chat_model.vision_enabled vision_available = chat_model.vision_enabled
@@ -1513,6 +1517,7 @@ def generate_chat_response(
generated_files=raw_generated_files, generated_files=raw_generated_files,
generated_asset_results=generated_asset_results, generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context, program_execution_context=program_execution_context,
deepthought=deepthought,
tracer=tracer, tracer=tracer,
) )
elif chat_model.model_type == ChatModel.ModelType.GOOGLE: elif chat_model.model_type == ChatModel.ModelType.GOOGLE:

View File

@@ -95,6 +95,7 @@ async def apick_next_tool(
query=query, query=query,
context=function_planning_prompt, context=function_planning_prompt,
response_type="json_object", response_type="json_object",
deepthought=True,
user=user, user=user,
query_images=query_images, query_images=query_images,
query_files=query_files, query_files=query_files,

View File

@@ -48,6 +48,9 @@ model_to_cost: Dict[str, Dict[str, float]] = {
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40}, "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-sonnet-latest": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-20250219": {"input": 3.0, "output": 15.0},
"claude-3-7-sonnet-latest": {"input": 3.0, "output": 15.0},
} }