mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Enforce JSON schema on more chat actors to improve schema compliance
This includes the infer webpage URLs, Gemini document search, and pick default mode tools chat actors.
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
@@ -96,12 +97,16 @@ def extract_questions_gemini(
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
class DocumentQueries(BaseModel):
|
||||
queries: List[str]
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages,
|
||||
api_key,
|
||||
model,
|
||||
api_base_url=api_base_url,
|
||||
response_type="json_object",
|
||||
response_schema=DocumentQueries,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format(
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class PickTools(BaseModel):
|
||||
source: List[str]
|
||||
output: str
|
||||
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
response_schema=PickTools,
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
@@ -483,11 +488,15 @@ async def infer_webpage_urls(
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
|
||||
class WebpageUrls(BaseModel):
|
||||
links: List[str]
|
||||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
response_schema=WebpageUrls,
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
@@ -563,11 +572,13 @@ async def generate_online_subqueries(
|
||||
response = pyjson5.loads(response)
|
||||
response = {q.strip() for q in response["queries"] if q.strip()}
|
||||
if not isinstance(response, set) or not response or len(response) == 0:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
||||
logger.error(
|
||||
f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}"
|
||||
)
|
||||
return {q}
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
||||
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
|
||||
return {q}
|
||||
|
||||
|
||||
@@ -2054,7 +2065,7 @@ def schedule_automation(
|
||||
try:
|
||||
user_timezone = pytz.timezone(timezone)
|
||||
except pytz.UnknownTimeZoneError:
|
||||
logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
|
||||
logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
|
||||
user_timezone = pytz.utc
|
||||
|
||||
trigger = CronTrigger.from_crontab(crontime, user_timezone)
|
||||
|
||||
Reference in New Issue
Block a user