Improve parsing complex json strings returned by LLM (#989)

- Improve escaping to load complex json objects
- Fallback to a more forgiving [json5](https://json5.org/) loader if `json.loads` cannot parse complex json str

This should reduce failures to pick research tool and run code by agent
This commit is contained in:
Debanjum
2024-11-28 11:01:39 -08:00
committed by GitHub
5 changed files with 60 additions and 8 deletions

View File

@@ -88,6 +88,7 @@ dependencies = [
"anthropic == 0.26.1",
"docx2txt == 0.8",
"google-generativeai == 0.8.3",
"pyjson5 == 1.6.7",
]
dynamic = ["version"]

View File

@@ -5,6 +5,7 @@ import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
@@ -14,6 +15,7 @@ from time import perf_counter
from typing import Any, Callable, Dict, List, Optional
import PIL.Image
import pyjson5
import requests
import tiktoken
import yaml
@@ -538,6 +540,47 @@ def clean_code_python(code: str):
return code.strip().removeprefix("```python").removesuffix("```")
def load_complex_json(json_str):
"""
Preprocess a raw JSON string to escape unescaped double quotes within value strings,
while preserving the JSON structure and already escaped quotes.
"""
def replace_unescaped_quotes(match):
# Get the content between colons and commas/end braces
content = match.group(1)
# Replace unescaped double, single quotes that aren't already escaped
# Uses negative lookbehind to avoid replacing already escaped quotes
# Replace " with \"
processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
# Replace \' with \\'
processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
return f': "{processed_final}"'
# Match content between : and either , or }
# This pattern looks for ': ' followed by any characters until , or }
pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'
# Process the JSON string
cleaned = clean_json(rf"{json_str}")
processed = re.sub(pattern, replace_unescaped_quotes, cleaned)
# See which json loader can load the processed JSON as valid
errors = []
json_loaders_to_try = [json.loads, pyjson5.loads]
for loads in json_loaders_to_try:
try:
return loads(processed)
except (json.JSONDecodeError, pyjson5.Json5Exception) as e:
errors.append(f"{type(e).__name__}: {str(e)}")
# If all loaders fail, raise the aggregated error
raise ValueError(
f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
f"While attempting to load this cleaned JSON:\n{processed}"
)
def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query

View File

@@ -1,6 +1,5 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
@@ -15,8 +14,8 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
load_complex_json,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
@@ -135,8 +134,7 @@ async def generate_python_code(
)
# Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
code = response.get("code", "").strip()
input_files = response.get("input_files", [])
input_links = response.get("input_links", [])

View File

@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
from typing import Callable, Dict, List, Optional
@@ -10,10 +9,10 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
clean_json,
construct_chat_history,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
@@ -106,8 +105,7 @@ async def apick_next_tool(
return
try:
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)

View File

@@ -104,6 +104,18 @@ class TestTruncateMessage:
assert truncated_chat_history[0] != copy_big_chat_message
def test_load_complex_raw_json_string():
# Arrange
raw_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
expeced_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}
# Act
parsed_json = utils.load_complex_json(raw_json)
# Assert
assert parsed_json == expeced_json
def generate_content(count):
return " ".join([f"{index}" for index, _ in enumerate(range(count))])