mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Improve loading of complex JSON objects. Use the new loader to pick tools and run code
Gemini doesn't work well when asked to output JSON objects. Having it output raw JSON strings with complex, multi-line structures instead requires more intensive clean-up of the raw JSON string before it can be parsed.
This commit is contained in:
@@ -5,6 +5,7 @@ import math
|
||||
import mimetypes
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
@@ -538,6 +539,46 @@ def clean_code_python(code: str):
|
||||
return code.strip().removeprefix("```python").removesuffix("```")
|
||||
|
||||
|
||||
def load_complex_json(json_str):
    """
    Parse a raw, possibly malformed JSON string into a Python object.

    The input is first passed through clean_json and parsed strictly. Only when
    that strict parse fails is the string repaired: unescaped double quotes inside
    string values are escaped, and a lone backslash-apostrophe (an invalid JSON
    escape) is rewritten as an escaped backslash followed by the apostrophe.
    The repaired string is then parsed.

    Trying the strict parse first keeps the heuristic repair from rewriting JSON
    that is already valid. The repair regex expects the closing quote of a value
    to be immediately followed by "," or "}", so on valid input with other
    spacing it can swallow structural quotes and silently change the result.

    Args:
        json_str: Raw JSON text. It is formatted into a string before cleaning.

    Returns:
        The object decoded from the JSON text.

    Raises:
        ValueError: If neither the cleaned nor the repaired string is valid JSON.
            The message aggregates every decoder error and includes the repaired string.
    """

    def replace_unescaped_quotes(match):
        # Text captured between the opening quote after ':' and the closing
        # quote that precedes ',' or '}'
        content = match.group(1)
        # Replace " with \" unless it is already preceded by a backslash
        processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
        # Replace \' with \\' unless that backslash is itself already escaped
        processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
        return f': "{processed_final}"'

    # Match ':' + optional whitespace + a quoted value. The value is matched lazily
    # up to the first unescaped quote that is directly followed by ',' or '}'.
    pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'

    cleaned = clean_json(rf"{json_str}")
    processed = re.sub(pattern, replace_unescaped_quotes, cleaned)

    # Prefer the untouched cleaned string; fall back to the repaired one only
    # when the repair actually changed something.
    candidates = [cleaned]
    if processed != cleaned:
        candidates.append(processed)

    # See which candidate / json loader combination yields valid JSON
    errors = []
    json_loaders_to_try = [json.loads]
    for candidate in candidates:
        for loads in json_loaders_to_try:
            try:
                return loads(candidate)
            except json.JSONDecodeError as e:
                errors.append(f"\n\n{e}")

    # If all attempts fail, raise the aggregated error
    errors = "".join(errors)
    raise ValueError(
        f"Failed to load JSON with error: {errors}\n\nWhile attempting to load this cleaned JSON:\n{processed}"
    )
|
||||
|
||||
|
||||
def defilter_query(query: str):
|
||||
"""Remove any query filters in query"""
|
||||
defiltered_query = query
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -15,8 +14,8 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
clean_code_python,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
|
||||
@@ -135,8 +134,7 @@ async def generate_python_code(
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = load_complex_json(response)
|
||||
code = response.get("code", "").strip()
|
||||
input_files = response.get("input_files", [])
|
||||
input_links = response.get("input_links", [])
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional
|
||||
@@ -10,10 +9,10 @@ from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
InformationCollectionIteration,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
@@ -106,8 +105,7 @@ async def apick_next_tool(
|
||||
return
|
||||
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = load_complex_json(response)
|
||||
selected_tool = response.get("tool", None)
|
||||
generated_query = response.get("query", None)
|
||||
scratchpad = response.get("scratchpad", None)
|
||||
|
||||
@@ -104,6 +104,18 @@ class TestTruncateMessage:
|
||||
assert truncated_chat_history[0] != copy_big_chat_message
|
||||
|
||||
|
||||
def test_load_complex_raw_json_string():
    """Check that load_complex_json parses a raw string whose value mixes escaped and unescaped quotes."""
    # Arrange: one value holding an unescaped double quote, a backslash-apostrophe,
    # an already escaped double quote and an already escaped backslash-apostrophe
    malformed_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
    expected_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

    # Act
    loaded_json = utils.load_complex_json(malformed_json)

    # Assert
    assert loaded_json == expected_json
|
||||
|
||||
|
||||
def generate_content(count):
    """Return the integers 0 through count - 1 as a single space-separated string.

    Yields an empty string when count is zero or negative.
    """
    # range() already produces the indices, so no enumerate() or intermediate list is needed
    return " ".join(str(index) for index in range(count))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user