diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index c6f744fa..90cd4df9 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -14,6 +14,7 @@ from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.utils import ( construct_structured_message, generate_chatml_messages_with_context, + remove_json_codeblock, ) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -85,6 +86,7 @@ def extract_questions( # Extract, Clean Message from GPT's Response try: response = response.strip() + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 6444b14d..03bd17a3 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -289,3 +289,10 @@ def truncate_messages( def reciprocal_conversation_to_chatml(message_pair): """Convert a single back and forth between user and assistant to chatml format""" return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] + + +def remove_json_codeblock(response): + """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models""" + if response.startswith("```json") and response.endswith("```"): + response = response[7:-3] + return response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5687937a..f1b8ddd6 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -88,6 +88,7 @@ from khoj.processor.conversation.openai.gpt import converse, send_message_to_mod from khoj.processor.conversation.utils import ( ThreadedGenerator, generate_chatml_messages_with_context, + remove_json_codeblock, save_to_conversation_log, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled @@ -298,9 +299,7 @@ async def aget_relevant_information_sources( try: response = response.strip() - # Remove any markdown json codeblock formatting if present (useful for gemma-2) - if response.startswith("```json"): - response = response[7:-3] + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["source"] if q.strip()] if not isinstance(response, list) or not response or len(response) == 0: @@ -353,7 +352,9 @@ async def aget_relevant_output_modes( response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object") try: - response = json.loads(response.strip()) + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) if is_none_or_empty(response): return ConversationCommand.Text @@ -433,9 +434,7 @@ async def generate_online_subqueries( # Validate that the response is a non-empty, JSON-serializable list try: response = response.strip() - # Remove any markdown json codeblock formatting if present (useful for gemma-2) - if response.startswith("```json") and response.endswith("```"): - response = response[7:-3] + response = remove_json_codeblock(response) response = json.loads(response) response = [q.strip() for q in response["queries"] if q.strip()] if not isinstance(response, list) or not response or len(response) == 0: