Enforce json schema on more chat actors to improve schema compliance

This covers the infer webpage URLs, Gemini documents search, and pick default
mode tools chat actors.
This commit is contained in:
Debanjum
2025-03-27 18:33:08 +05:30
parent ccd9de7792
commit a387f638cd
2 changed files with 19 additions and 3 deletions

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from pydantic import BaseModel
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
@@ -96,12 +97,16 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=prompt, role="user"))
messages.append(ChatMessage(content=system_prompt, role="system"))
# Response schema for the Gemini call below (passed as response_schema=DocumentQueries
# together with response_type="json_object"): one field, `queries`, a list of strings.
class DocumentQueries(BaseModel):
queries: List[str]
response = gemini_send_message_to_model(
messages,
api_key,
model,
api_base_url=api_base_url,
response_type="json_object",
response_schema=DocumentQueries,
tracer=tracer,
)

View File

@@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format(
agent_chat_model = agent.chat_model if agent else None
# Response schema for the "Infer information sources to refer" chat actor below
# (passed as response_schema=PickTools): `source` is a list of strings, `output` a single string.
# NOTE(review): presumably `source` holds the chosen data sources and `output` the chosen
# output format, going by the enclosing function's name — confirm against the prompt.
class PickTools(BaseModel):
source: List[str]
output: str
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
response_schema=PickTools,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
@@ -483,11 +488,15 @@ async def infer_webpage_urls(
agent_chat_model = agent.chat_model if agent else None
# Response schema for the "Infer webpage urls to read" chat actor below
# (passed as response_schema=WebpageUrls): one field, `links`, a list of strings.
class WebpageUrls(BaseModel):
links: List[str]
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
query_images=query_images,
response_type="json_object",
response_schema=WebpageUrls,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
@@ -563,11 +572,13 @@ async def generate_online_subqueries(
response = pyjson5.loads(response)
response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
logger.error(
f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}"
)
return {q}
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
return {q}
@@ -2054,7 +2065,7 @@ def schedule_automation(
try:
user_timezone = pytz.timezone(timezone)
except pytz.UnknownTimeZoneError:
logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
user_timezone = pytz.utc
trigger = CronTrigger.from_crontab(crontime, user_timezone)