mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 05:39:12 +00:00
Enforce JSON schema on more chat actors to improve schema compliance
This includes the infer webpage URLs, Gemini document search, and pick default mode tools chat actors.
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
@@ -96,12 +97,16 @@ def extract_questions_gemini(
|
|||||||
messages.append(ChatMessage(content=prompt, role="user"))
|
messages.append(ChatMessage(content=prompt, role="user"))
|
||||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||||
|
|
||||||
|
class DocumentQueries(BaseModel):
|
||||||
|
queries: List[str]
|
||||||
|
|
||||||
response = gemini_send_message_to_model(
|
response = gemini_send_message_to_model(
|
||||||
messages,
|
messages,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
model,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
response_schema=DocumentQueries,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format(
|
|||||||
|
|
||||||
agent_chat_model = agent.chat_model if agent else None
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
|
class PickTools(BaseModel):
|
||||||
|
source: List[str]
|
||||||
|
output: str
|
||||||
|
|
||||||
with timer("Chat actor: Infer information sources to refer", logger):
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
response_schema=PickTools,
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
@@ -483,11 +488,15 @@ async def infer_webpage_urls(
|
|||||||
|
|
||||||
agent_chat_model = agent.chat_model if agent else None
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
|
class WebpageUrls(BaseModel):
|
||||||
|
links: List[str]
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt,
|
online_queries_prompt,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
response_schema=WebpageUrls,
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
@@ -563,11 +572,13 @@ async def generate_online_subqueries(
|
|||||||
response = pyjson5.loads(response)
|
response = pyjson5.loads(response)
|
||||||
response = {q.strip() for q in response["queries"] if q.strip()}
|
response = {q.strip() for q in response["queries"] if q.strip()}
|
||||||
if not isinstance(response, set) or not response or len(response) == 0:
|
if not isinstance(response, set) or not response or len(response) == 0:
|
||||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
logger.error(
|
||||||
|
f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}"
|
||||||
|
)
|
||||||
return {q}
|
return {q}
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
|
||||||
return {q}
|
return {q}
|
||||||
|
|
||||||
|
|
||||||
@@ -2054,7 +2065,7 @@ def schedule_automation(
|
|||||||
try:
|
try:
|
||||||
user_timezone = pytz.timezone(timezone)
|
user_timezone = pytz.timezone(timezone)
|
||||||
except pytz.UnknownTimeZoneError:
|
except pytz.UnknownTimeZoneError:
|
||||||
logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
|
logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
|
||||||
user_timezone = pytz.utc
|
user_timezone = pytz.utc
|
||||||
|
|
||||||
trigger = CronTrigger.from_crontab(crontime, user_timezone)
|
trigger = CronTrigger.from_crontab(crontime, user_timezone)
|
||||||
|
|||||||
Reference in New Issue
Block a user