Enforce JSON schema on more chat actors to improve schema compliance

This includes the infer webpage URLs, Gemini documents search, and pick
default mode tools chat actors.
This commit is contained in:
Debanjum
2025-03-27 18:33:08 +05:30
parent ccd9de7792
commit a387f638cd
2 changed files with 19 additions and 3 deletions

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
import pyjson5 import pyjson5
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from pydantic import BaseModel
from khoj.database.models import Agent, ChatModel, KhojUser from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
@@ -96,12 +97,16 @@ def extract_questions_gemini(
messages.append(ChatMessage(content=prompt, role="user")) messages.append(ChatMessage(content=prompt, role="user"))
messages.append(ChatMessage(content=system_prompt, role="system")) messages.append(ChatMessage(content=system_prompt, role="system"))
class DocumentQueries(BaseModel):
queries: List[str]
response = gemini_send_message_to_model( response = gemini_send_message_to_model(
messages, messages,
api_key, api_key,
model, model,
api_base_url=api_base_url, api_base_url=api_base_url,
response_type="json_object", response_type="json_object",
response_schema=DocumentQueries,
tracer=tracer, tracer=tracer,
) )

View File

@@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format(
agent_chat_model = agent.chat_model if agent else None agent_chat_model = agent.chat_model if agent else None
class PickTools(BaseModel):
source: List[str]
output: str
with timer("Chat actor: Infer information sources to refer", logger): with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
relevant_tools_prompt, relevant_tools_prompt,
response_type="json_object", response_type="json_object",
response_schema=PickTools,
user=user, user=user,
query_files=query_files, query_files=query_files,
agent_chat_model=agent_chat_model, agent_chat_model=agent_chat_model,
@@ -483,11 +488,15 @@ async def infer_webpage_urls(
agent_chat_model = agent.chat_model if agent else None agent_chat_model = agent.chat_model if agent else None
class WebpageUrls(BaseModel):
links: List[str]
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, online_queries_prompt,
query_images=query_images, query_images=query_images,
response_type="json_object", response_type="json_object",
response_schema=WebpageUrls,
user=user, user=user,
query_files=query_files, query_files=query_files,
agent_chat_model=agent_chat_model, agent_chat_model=agent_chat_model,
@@ -563,11 +572,13 @@ async def generate_online_subqueries(
response = pyjson5.loads(response) response = pyjson5.loads(response)
response = {q.strip() for q in response["queries"] if q.strip()} response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0: if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") logger.error(
f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}"
)
return {q} return {q}
return response return response
except Exception as e: except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
return {q} return {q}
@@ -2054,7 +2065,7 @@ def schedule_automation(
try: try:
user_timezone = pytz.timezone(timezone) user_timezone = pytz.timezone(timezone)
except pytz.UnknownTimeZoneError: except pytz.UnknownTimeZoneError:
logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.") logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
user_timezone = pytz.utc user_timezone = pytz.utc
trigger = CronTrigger.from_crontab(crontime, user_timezone) trigger = CronTrigger.from_crontab(crontime, user_timezone)