From a387f638cddaa2297eadd0fa54e1764670b1757a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 27 Mar 2025 18:33:08 +0530 Subject: [PATCH] Enforce json schema on more chat actors to improve schema compliance Including infer webpage urls, gemini documents search, pick default mode tools chat actors --- .../conversation/google/gemini_chat.py | 5 +++++ src/khoj/routers/helpers.py | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 3c630dec..eec24d92 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage +from pydantic import BaseModel from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts @@ -96,12 +97,16 @@ def extract_questions_gemini( messages.append(ChatMessage(content=prompt, role="user")) messages.append(ChatMessage(content=system_prompt, role="system")) + class DocumentQueries(BaseModel): + queries: List[str] + response = gemini_send_message_to_model( messages, api_key, model, api_base_url=api_base_url, response_type="json_object", + response_schema=DocumentQueries, tracer=tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index acca60b0..28808ecd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format( agent_chat_model = agent.chat_model if agent else None + class PickTools(BaseModel): + source: List[str] + output: str + with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( relevant_tools_prompt, response_type="json_object", + response_schema=PickTools, user=user, query_files=query_files, agent_chat_model=agent_chat_model, @@ -483,11 +488,15 @@ async def infer_webpage_urls( agent_chat_model = agent.chat_model if agent else None + class WebpageUrls(BaseModel): + links: List[str] + with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( online_queries_prompt, query_images=query_images, response_type="json_object", + response_schema=WebpageUrls, user=user, query_files=query_files, agent_chat_model=agent_chat_model, @@ -563,11 +572,13 @@ async def generate_online_subqueries( response = pyjson5.loads(response) response = {q.strip() for q in response["queries"] if q.strip()} if not isinstance(response, set) or not response or len(response) == 0: - logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") + logger.error( + f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}" + ) return {q} return response except Exception as e: - logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") + logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}") return {q} @@ -2054,7 +2065,7 @@ def schedule_automation( try: user_timezone = pytz.timezone(timezone) except pytz.UnknownTimeZoneError: - logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.") + logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.") user_timezone = pytz.utc trigger = CronTrigger.from_crontab(crontime, user_timezone)