diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index f0568ccc..feb587b2 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -16,6 +16,7 @@ from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -91,15 +92,13 @@ def extract_questions_anthropic( model_name=model, temperature=temperature, api_key=api_key, + response_type="json_object", tracer=tracer, ) # Extract, Clean Message from Claude's Response try: - response = response.strip() - match = re.search(r"\{.*?\}", response) - if match: - response = match.group() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: @@ -113,7 +112,7 @@ def extract_questions_anthropic( return questions -def anthropic_send_message_to_model(messages, api_key, model, tracer={}): +def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}): """ Send message to model """ @@ -125,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}): system_prompt=system_prompt, model_name=model, api_key=api_key, + response_type=response_type, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 6673555b..cdce63c6 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000 reraise=True, ) def anthropic_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={} + messages, + system_prompt, + model_name, + temperature=0, + api_key=None, + model_kwargs=None, + max_tokens=None, + response_type="text", + tracer={}, ) -> str: if api_key not in anthropic_clients: client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) @@ -44,8 +52,11 @@ def anthropic_completion_with_backoff( client = anthropic_clients[api_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + if response_type == "json_object": + # Prefill model response with '{' to make it output a valid JSON object + formatted_messages += [{"role": "assistant", "content": "{"}] - aggregated_response = "" + aggregated_response = "{" if response_type == "json_object" else "" max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC model_kwargs = model_kwargs or dict() diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 44f24d4b..e9538c0c 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -16,6 +16,7 @@ from khoj.processor.conversation.google.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -92,10 +93,7 @@ def extract_questions_gemini( # Extract, Clean Message from Gemini's Response try: - response = response.strip() - match = re.search(r"\{.*?\}", response) - if match: - response = match.group() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 52958cc1..3fd0aeb1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -978,6 +978,7 @@ async def send_message_to_model_wrapper( messages=truncated_messages, api_key=api_key, model=chat_model, + response_type=response_type, tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.GOOGLE: @@ -1078,6 +1079,7 @@ def send_message_to_model_wrapper_sync( messages=truncated_messages, api_key=api_key, model=chat_model, + response_type=response_type, tracer=tracer, )