Improve parsing complex json strings returned by LLM (#989)

- Improve escaping to load complex json objects
- Fall back to a more forgiving [json5](https://json5.org/) loader if `json.loads` cannot parse a complex JSON string

This should reduce failures when the agent picks a research tool or runs code.
This commit is contained in:
Debanjum
2024-11-28 11:01:39 -08:00
committed by GitHub
5 changed files with 60 additions and 8 deletions

View File

@@ -88,6 +88,7 @@ dependencies = [
"anthropic == 0.26.1", "anthropic == 0.26.1",
"docx2txt == 0.8", "docx2txt == 0.8",
"google-generativeai == 0.8.3", "google-generativeai == 0.8.3",
"pyjson5 == 1.6.7",
] ]
dynamic = ["version"] dynamic = ["version"]

View File

@@ -5,6 +5,7 @@ import math
import mimetypes import mimetypes
import os import os
import queue import queue
import re
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
@@ -14,6 +15,7 @@ from time import perf_counter
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import PIL.Image import PIL.Image
import pyjson5
import requests import requests
import tiktoken import tiktoken
import yaml import yaml
@@ -538,6 +540,47 @@ def clean_code_python(code: str):
return code.strip().removeprefix("```python").removesuffix("```") return code.strip().removeprefix("```python").removesuffix("```")
def load_complex_json(json_str):
    """
    Parse a JSON-like string, such as a raw LLM response, into a Python object.

    The input is first passed through `clean_json`. If the cleaned text is already
    valid JSON it is parsed unchanged. Otherwise double quotes left unescaped inside
    string values are escaped, and the repaired text is tried with `json.loads`
    followed by the more forgiving `pyjson5.loads`.

    Args:
        json_str: The raw JSON text to load.

    Returns:
        The object produced by the first loader that accepts the text.

    Raises:
        ValueError: If no loader can parse the text. The message lists the error
            from each loader and the repaired JSON that was attempted.
    """

    def replace_unescaped_quotes(match):
        # Text captured between the opening and closing double quote of a value
        content = match.group(1)
        # Escape double quotes that are not already preceded by a backslash: " -> \"
        processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
        # Double the backslash of a \' that is not itself preceded by a backslash: \' -> \\'
        processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
        return f': "{processed_final}"'

    # Match a colon, optional whitespace and an opening quote, then lazily capture
    # up to a closing quote that is not backslash-escaped and is immediately
    # followed by , or }
    pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'

    cleaned = clean_json(str(json_str))

    # Already valid JSON needs no repair. Parse it untouched, because the quote
    # escaping below can corrupt valid input, e.g. a value that ends in an escaped
    # backslash such as {"a": "b\\", "c": "d"}.
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Not strict JSON as is; fall through to the repair path
        pass

    processed = re.sub(pattern, replace_unescaped_quotes, cleaned)

    # See which json loader can load the processed JSON as valid
    errors = []
    json_loaders_to_try = [json.loads, pyjson5.loads]
    for loads in json_loaders_to_try:
        try:
            return loads(processed)
        except (json.JSONDecodeError, pyjson5.Json5Exception) as e:
            errors.append(f"{type(e).__name__}: {str(e)}")

    # If all loaders fail, raise the aggregated error
    raise ValueError(
        f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
        f"While attempting to load this cleaned JSON:\n{processed}"
    )
def defilter_query(query: str): def defilter_query(query: str):
"""Remove any query filters in query""" """Remove any query filters in query"""
defiltered_query = query defiltered_query = query

View File

@@ -1,6 +1,5 @@
import base64 import base64
import datetime import datetime
import json
import logging import logging
import mimetypes import mimetypes
import os import os
@@ -15,8 +14,8 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ChatEvent, ChatEvent,
clean_code_python, clean_code_python,
clean_json,
construct_chat_history, construct_chat_history,
load_complex_json,
) )
from khoj.routers.helpers import send_message_to_model_wrapper from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
@@ -135,8 +134,7 @@ async def generate_python_code(
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response) response = load_complex_json(response)
response = json.loads(response)
code = response.get("code", "").strip() code = response.get("code", "").strip()
input_files = response.get("input_files", []) input_files = response.get("input_files", [])
input_links = response.get("input_links", []) input_links = response.get("input_links", [])

View File

@@ -1,4 +1,3 @@
import json
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
@@ -10,10 +9,10 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
InformationCollectionIteration, InformationCollectionIteration,
clean_json,
construct_chat_history, construct_chat_history,
construct_iteration_history, construct_iteration_history,
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json,
) )
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code from khoj.processor.tools.run_code import run_code
@@ -106,8 +105,7 @@ async def apick_next_tool(
return return
try: try:
response = clean_json(response) response = load_complex_json(response)
response = json.loads(response)
selected_tool = response.get("tool", None) selected_tool = response.get("tool", None)
generated_query = response.get("query", None) generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None) scratchpad = response.get("scratchpad", None)

View File

@@ -104,6 +104,18 @@ class TestTruncateMessage:
assert truncated_chat_history[0] != copy_big_chat_message assert truncated_chat_history[0] != copy_big_chat_message
def test_load_complex_raw_json_string():
    """A raw JSON string mixing unescaped and escaped quotes loads into the expected dict."""
    # Arrange: the value holds an unescaped ", a \', an escaped \" and a \\'
    llm_output = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
    expected = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

    # Act and Assert
    assert utils.load_complex_json(llm_output) == expected
def generate_content(count):
    """Return the integers 0 to count - 1 as a single space-separated string."""
    return " ".join(str(index) for index in range(count))