Improve load of complex json objects. Use it to pick tool, run code

Gemini doesn't work well when asked to output json objects. Using it
to output raw json strings with complex, multi-line structures
requires more intensive clean-up of the raw json string before parsing.
This commit is contained in:
Debanjum
2024-11-26 15:35:23 -08:00
parent 8cb0db0051
commit 70b7e7c73a
4 changed files with 57 additions and 8 deletions

View File

@@ -5,6 +5,7 @@ import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
@@ -538,6 +539,46 @@ def clean_code_python(code: str):
return code.strip().removeprefix("```python").removesuffix("```")
def load_complex_json(json_str):
    """
    Load a raw, possibly malformed JSON string into a Python object.

    The input is coerced to a string and passed through clean_json (defined
    elsewhere in this module). If the cleaned string is already valid JSON it
    is parsed and returned as is. Only when strict parsing fails are the
    repair heuristics applied: unescaped double quotes inside string values
    are escaped, and \\' sequences get their backslash doubled, while
    already escaped quotes are preserved.

    Args:
        json_str: Raw JSON text to load.

    Returns:
        The Python object decoded from the (possibly repaired) JSON string.

    Raises:
        ValueError: If the string cannot be decoded even after repair. The
            message lists the decode error(s) and the repaired string.
    """

    def replace_unescaped_quotes(match):
        # Get the content between the opening and closing quotes of the value
        content = match.group(1)
        # Replace " with \" unless it is already preceded by a backslash
        processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
        # Replace \' with \\' since \' is not a valid JSON escape sequence
        processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
        return f': "{processed_final}"'

    # Match a string value: a colon, optional whitespace, an opening quote, then
    # the shortest run of characters up to an unescaped quote that is
    # immediately followed by , or }
    pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'

    # Coerce the input to a string before cleaning it
    cleaned = clean_json(f"{json_str}")

    # Valid JSON must be returned untouched. The repair regex is a heuristic and
    # can mangle well-formed input, e.g. when whitespace precedes a comma or a
    # value ends with an escaped backslash.
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Strict parsing failed, so attempt to repair the quotes in string values
        processed = re.sub(pattern, replace_unescaped_quotes, cleaned)

    # See which json loader can load the processed JSON as valid
    error_messages = []
    last_error = None
    json_loaders_to_try = [json.loads]
    for loads in json_loaders_to_try:
        try:
            return loads(processed)
        except json.JSONDecodeError as e:
            error_messages.append(f"\n\n{e}")
            last_error = e

    # If all loaders fail, raise the aggregated error
    errors = "".join(error_messages)
    raise ValueError(
        f"Failed to load JSON with error: {errors}\n\nWhile attempting to load this cleaned JSON:\n{processed}"
    ) from last_error
def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query

View File

@@ -1,6 +1,5 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
@@ -15,8 +14,8 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
load_complex_json,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
@@ -135,8 +134,7 @@ async def generate_python_code(
)
# Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
code = response.get("code", "").strip()
input_files = response.get("input_files", [])
input_links = response.get("input_links", [])

View File

@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
from typing import Callable, Dict, List, Optional
@@ -10,10 +9,10 @@ from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
clean_json,
construct_chat_history,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
@@ -106,8 +105,7 @@ async def apick_next_tool(
return
try:
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)

View File

@@ -104,6 +104,18 @@ class TestTruncateMessage:
assert truncated_chat_history[0] != copy_big_chat_message
def test_load_complex_raw_json_string():
    """Check that load_complex_json parses a raw string mixing escaped and unescaped quotes."""
    # Arrange: the raw payload holds a bare ", a \', an escaped \" and a \\' inside one value
    malformed_payload = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
    expected_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

    # Act
    actual_json = utils.load_complex_json(malformed_payload)

    # Assert
    assert actual_json == expected_json
def generate_content(count):
    """Return the integers 0 through count - 1 as a single space-separated string."""
    numbers = (str(position) for position in range(count))
    return " ".join(numbers)