mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Do not overwrite charts created in previous code tool use during research
This commit is contained in:
@@ -252,8 +252,12 @@ async def execute_e2b(code: str, input_files: list[dict]) -> dict[str, Any]:
|
|||||||
|
|
||||||
# Identify new files created during execution
|
# Identify new files created during execution
|
||||||
new_files = set(E2bFile(f.name, f.path) for f in await sandbox.files.list("~")) - original_files
|
new_files = set(E2bFile(f.name, f.path) for f in await sandbox.files.list("~")) - original_files
|
||||||
|
|
||||||
# Read newly created files in parallel
|
# Read newly created files in parallel
|
||||||
download_tasks = [sandbox.files.read(f.path, request_timeout=30) for f in new_files]
|
def read_format(f):
|
||||||
|
return "bytes" if Path(f.name).suffix in image_file_ext else "text"
|
||||||
|
|
||||||
|
download_tasks = [sandbox.files.read(f.path, format=read_format(f), request_timeout=30) for f in new_files]
|
||||||
downloaded_files = await asyncio.gather(*download_tasks)
|
downloaded_files = await asyncio.gather(*download_tasks)
|
||||||
for f, content in zip(new_files, downloaded_files):
|
for f, content in zip(new_files, downloaded_files):
|
||||||
if isinstance(content, bytes):
|
if isinstance(content, bytes):
|
||||||
@@ -261,23 +265,12 @@ async def execute_e2b(code: str, input_files: list[dict]) -> dict[str, Any]:
|
|||||||
b64_data = base64.b64encode(content).decode("utf-8")
|
b64_data = base64.b64encode(content).decode("utf-8")
|
||||||
elif Path(f.name).suffix in image_file_ext:
|
elif Path(f.name).suffix in image_file_ext:
|
||||||
# Ignore image files as they are extracted from execution results below for inline display
|
# Ignore image files as they are extracted from execution results below for inline display
|
||||||
continue
|
b64_data = base64.b64encode(content).decode("utf-8")
|
||||||
else:
|
else:
|
||||||
# Text files - encode utf-8 string as base64
|
# Text files - encode utf-8 string as base64
|
||||||
b64_data = content
|
b64_data = content
|
||||||
output_files.append({"filename": f.name, "b64_data": b64_data})
|
output_files.append({"filename": f.name, "b64_data": b64_data})
|
||||||
|
|
||||||
# Collect output files from execution results
|
|
||||||
# Repect ordering of output result types to disregard text output associated with images
|
|
||||||
output_result_types = ["png", "jpeg", "svg", "text", "markdown", "json"]
|
|
||||||
for idx, result in enumerate(execution.results):
|
|
||||||
if getattr(result, "chart", None):
|
|
||||||
continue
|
|
||||||
for result_type in output_result_types:
|
|
||||||
if b64_data := getattr(result, result_type, None):
|
|
||||||
output_files.append({"filename": f"{idx}.{result_type}", "b64_data": b64_data})
|
|
||||||
break
|
|
||||||
|
|
||||||
# collect logs
|
# collect logs
|
||||||
success = not execution.error and not execution.logs.stderr
|
success = not execution.error and not execution.logs.stderr
|
||||||
stdout = "\n".join(execution.logs.stdout)
|
stdout = "\n".join(execution.logs.stdout)
|
||||||
|
|||||||
Reference in New Issue
Block a user