Working prototype of meta-level chain of reasoning and execution

- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information
- Lots of hacks and code that needs to be reversed later on before submission
This commit is contained in:
sabaimran
2024-10-09 15:54:25 -07:00
parent 00546c1a63
commit f867d5ed72
6 changed files with 906 additions and 531 deletions

View File

@@ -485,6 +485,47 @@ Khoj:
""".strip() """.strip()
) )
# Planning prompt for the multi-step research loop: given the chat history and
# the results of previous iterations, the model either picks the next data
# source + query (as a JSON object) or returns {} to stop researching.
plan_function_execution = PromptTemplate.from_template(
    """
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time
- You are given multiple iterations with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution

If you already know the answer to the question, return an empty response, e.g., {{}}.

Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:

{tools}

Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.

Previous Iterations:
{previous_iterations}

Response format:
{{"data_source": "<tool_name>", "query": "<your_new_query>"}}

Chat History:
{chat_history}

Q: {query}
Khoj:
""".strip()
)
# Renders one prior research iteration (the tool used, the query issued, and
# the document/online context it produced) for interpolation into a planning
# prompt's {previous_iterations} slot.
previous_iteration = PromptTemplate.from_template(
    """
data_source: {data_source}
query: {query}
context: {context}
onlineContext: {onlineContext}
---
""".strip()
)
pick_relevant_information_collection_tools = PromptTemplate.from_template( pick_relevant_information_collection_tools = PromptTemplate.from_template(
""" """
You are Khoj, an extremely smart and helpful search assistant. You are Khoj, an extremely smart and helpful search assistant.

View File

@@ -355,9 +355,10 @@ async def extract_references_and_questions(
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
if ( if (
not ConversationCommand.Notes in conversation_commands # not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands # and not ConversationCommand.Default in conversation_commands
and not agent_has_entries # and not agent_has_entries
True
): ):
yield compiled_references, inferred_queries, q yield compiled_references, inferred_queries, q
return return

View File

@@ -41,7 +41,8 @@ from khoj.routers.helpers import (
aget_relevant_output_modes, aget_relevant_output_modes,
construct_automation_created_message, construct_automation_created_message,
create_automation, create_automation,
extract_relevant_summary, extract_relevant_info,
generate_summary_from_files,
get_conversation_command, get_conversation_command,
is_query_empty, is_query_empty,
is_ready_to_chat, is_ready_to_chat,
@@ -49,6 +50,10 @@ from khoj.routers.helpers import (
update_telemetry_state, update_telemetry_state,
validate_conversation_config, validate_conversation_config,
) )
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.storage import upload_image_to_bucket from khoj.routers.storage import upload_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
@@ -689,7 +694,46 @@ async def chat(
meta_log = conversation.conversation_log meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
pending_research = True
researched_results = ""
online_results: Dict = dict()
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
async for research_result in execute_information_collection(
request=request,
user=user,
query=q,
conversation_id=conversation_id,
conversation_history=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
location=location,
file_filters=conversation.file_filters if conversation else [],
):
if type(research_result) == InformationCollectionIteration:
pending_research = False
if research_result.onlineContext:
researched_results += str(research_result.onlineContext)
online_results.update(research_result.onlineContext)
if research_result.context:
researched_results += str(research_result.context)
compiled_references.extend(research_result.context)
else:
yield research_result
researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}")
pending_research = False
conversation_commands = await aget_relevant_information_sources( conversation_commands = await aget_relevant_information_sources(
q, q,
meta_log, meta_log,
@@ -724,9 +768,11 @@ async def chat(
and not used_slash_summarize and not used_slash_summarize
# but we can't actually summarize # but we can't actually summarize
and len(file_filters) != 1 and len(file_filters) != 1
# not pending research
and not pending_research
): ):
conversation_commands.remove(ConversationCommand.Summarize) conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands: elif ConversationCommand.Summarize in conversation_commands and pending_research:
response_log = "" response_log = ""
agent_has_entries = await EntryAdapters.aagent_has_entries(agent) agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries: if len(file_filters) == 0 and not agent_has_entries:
@@ -738,47 +784,15 @@ async def chat(
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
else: else:
try: response_log = await generate_summary_from_files(
file_object = None q=query,
if await EntryAdapters.aagent_has_entries(agent): user=user,
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) file_filters=file_filters,
if len(file_names) > 0: meta_log=meta_log,
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
None, file_names[0], agent
)
if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed, subscribed=subscribed,
uploaded_image_url=uploaded_image_url, send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent, send_response_func=partial(send_llm_response),
) )
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
response_log, response_log,
@@ -838,8 +852,6 @@ async def chat(
return return
# Gather Context # Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions( async for result in extract_references_and_questions(
request, request,
meta_log, meta_log,
@@ -867,8 +879,6 @@ async def chat(
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"): async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result yield result
@@ -878,7 +888,7 @@ async def chat(
conversation_commands.remove(ConversationCommand.Notes) conversation_commands.remove(ConversationCommand.Notes)
## Gather Online References ## Gather Online References
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands and pending_research:
try: try:
async for result in search_online( async for result in search_online(
defiltered_query, defiltered_query,
@@ -903,7 +913,7 @@ async def chat(
return return
## Gather Webpage References ## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands and pending_research:
try: try:
async for result in read_webpages( async for result in read_webpages(
defiltered_query, defiltered_query,
@@ -1008,6 +1018,7 @@ async def chat(
defiltered_query, defiltered_query,
meta_log, meta_log,
conversation, conversation,
researched_results,
compiled_references, compiled_references,
online_results, online_results,
inferred_queries, inferred_queries,
@@ -1051,36 +1062,32 @@ async def chat(
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
# Deprecated API. Remove by end of September 2024 # @api_chat.post("")
@api_chat.get("")
@requires(["authenticated"]) @requires(["authenticated"])
async def get_chat( async def old_chat(
request: Request, request: Request,
common: CommonQueryParams, common: CommonQueryParams,
q: str, body: ChatRequestBody,
n: int = 7,
d: float = None,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[str] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
image: Optional[str] = None,
rate_limiter_per_minute=Depends( rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute") ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
), ),
rate_limiter_per_day=Depends( rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
), ),
): ):
# Issue a deprecation warning # Access the parameters from the body
warnings.warn( q = body.q
"The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.", n = body.n
DeprecationWarning, d = body.d
stacklevel=2, stream = body.stream
) title = body.title
conversation_id = body.conversation_id
city = body.city
region = body.region
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image
async def event_generator(q: str, image: str): async def event_generator(q: str, image: str):
start_time = time.perf_counter() start_time = time.perf_counter()
@@ -1108,7 +1115,7 @@ async def get_chat(
nonlocal connection_alive, ttft nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
connection_alive = False connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client") logger.warning(f"User {user} disconnected from {common.client} client")
return return
try: try:
if event_type == ChatEvent.END_LLM_RESPONSE: if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -1164,21 +1171,34 @@ async def get_chat(
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title user,
client_application=request.user.client_app,
conversation_id=conversation_id,
title=title,
create_new=body.create_new,
) )
if not conversation: if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"): async for result in send_llm_response(f"Conversation {conversation_id} not found"):
yield result yield result
return return
conversation_id = conversation.id conversation_id = conversation.id
agent = conversation.agent if conversation.agent else None
agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
agent = conversation.agent
if not conversation.agent:
conversation.agent = default_agent
await conversation.asave()
agent = default_agent
await is_ready_to_chat(user) await is_ready_to_chat(user)
user_name = await aget_user_name(user) user_name = await aget_user_name(user)
location = None location = None
if city or region or country: if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country) location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q): if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."): async for result in send_llm_response("Please ask your query to get started."):
@@ -1192,7 +1212,12 @@ async def get_chat(
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources( conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url q,
meta_log,
is_automated_task,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
) )
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event( async for result in send_event(
@@ -1200,7 +1225,7 @@ async def get_chat(
): ):
yield result yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result yield result
if mode not in conversation_commands: if mode not in conversation_commands:
@@ -1224,45 +1249,25 @@ async def get_chat(
conversation_commands.remove(ConversationCommand.Summarize) conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands: elif ConversationCommand.Summarize in conversation_commands:
response_log = "" response_log = ""
if len(file_filters) == 0: agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
response_log = "No files selected for summarization. Please add files using the section on the left." response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
elif len(file_filters) > 1: elif len(file_filters) > 1 and not agent_has_entries:
response_log = "Only one file can be selected for summarization." response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
else: else:
try: response_log = await generate_summary_from_files(
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) q=query,
if len(file_object) == 0: user=user,
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." file_filters=file_filters,
async for result in send_llm_response(response_log): meta_log=meta_log,
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed, subscribed=subscribed,
uploaded_image_url=uploaded_image_url, send_status_func=partial(send_event, ChatEvent.STATUS),
send_response_func=partial(send_llm_response),
) )
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
response_log, response_log,
@@ -1335,6 +1340,7 @@ async def get_chat(
location, location,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -1350,8 +1356,6 @@ async def get_chat(
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"): async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result yield result
@@ -1372,6 +1376,7 @@ async def get_chat(
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
custom_filters, custom_filters,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -1395,6 +1400,7 @@ async def get_chat(
subscribed, subscribed,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@@ -1441,6 +1447,7 @@ async def get_chat(
subscribed=subscribed, subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]

View File

@@ -14,6 +14,7 @@ from typing import (
Annotated, Annotated,
Any, Any,
AsyncGenerator, AsyncGenerator,
Callable,
Dict, Dict,
Iterator, Iterator,
List, List,
@@ -39,6 +40,7 @@ from khoj.database.adapters import (
AutomationAdapters, AutomationAdapters,
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
FileObjectAdapters,
create_khoj_token, create_khoj_token,
get_khoj_tokens, get_khoj_tokens,
get_user_name, get_user_name,
@@ -614,6 +616,58 @@ async def extract_relevant_summary(
return response.strip() return response.strip()
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    meta_log: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    send_response_func: Optional[Callable] = None,
):
    """
    Summarize the full text of the selected file (or the agent's first file) for query `q`.

    Async generator: yields whatever `send_status_func` / `send_response_func` yield;
    when `send_response_func` is not provided, yields the response text directly.
    Errors are logged and reported as a user-facing message rather than raised.
    """

    async def _respond(message: str):
        # Robustness fix: fall back to yielding the raw message when the caller
        # did not supply a response streaming function (e.g. the research loop).
        if send_response_func:
            async for result in send_response_func(message):
                yield result
        else:
            yield message

    try:
        # Bug fix: start from an empty list instead of None so that when neither the
        # agent nor the user supplies a file, the "couldn't find the full text" path
        # below is taken (len(None) used to raise and hit the generic error handler).
        file_object = []
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent)

        # A user-selected file filter takes precedence over the agent's own files.
        if len(file_filters) > 0:
            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

        if len(file_object) == 0:
            response_log = (
                "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
            )
            async for result in _respond(response_log):
                yield result
            return

        contextual_data = " ".join([file.raw_text for file in file_object])
        if not q:
            q = "Create a general summary of the file"
        if send_status_func:
            async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
                yield result

        response = await extract_relevant_summary(
            q,
            contextual_data,
            conversation_history=meta_log,
            subscribed=subscribed,
            uploaded_image_url=uploaded_image_url,
            agent=agent,
        )
        response_log = str(response)
        async for result in _respond(response_log):
            yield result
    except Exception as e:
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        async for result in _respond(response_log):
            yield result
async def generate_better_image_prompt( async def generate_better_image_prompt(
q: str, q: str,
conversation_history: str, conversation_history: str,
@@ -893,6 +947,7 @@ def generate_chat_response(
q: str, q: str,
meta_log: dict, meta_log: dict,
conversation: Conversation, conversation: Conversation,
meta_research: str = "",
compiled_references: List[Dict] = [], compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {}, online_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
@@ -910,6 +965,9 @@ def generate_chat_response(
metadata = {} metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try: try:
partial_completion = partial( partial_completion = partial(
save_to_conversation_log, save_to_conversation_log,
@@ -937,7 +995,7 @@ def generate_chat_response(
chat_response = converse_offline( chat_response = converse_offline(
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
user_query=q, user_query=query_to_run,
loaded_model=loaded_model, loaded_model=loaded_model,
conversation_log=meta_log, conversation_log=meta_log,
completion_func=partial_completion, completion_func=partial_completion,
@@ -956,7 +1014,7 @@ def generate_chat_response(
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
chat_response = converse( chat_response = converse(
compiled_references, compiled_references,
q, query_to_run,
image_url=uploaded_image_url, image_url=uploaded_image_url,
online_results=online_results, online_results=online_results,
conversation_log=meta_log, conversation_log=meta_log,
@@ -977,7 +1035,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic( chat_response = converse_anthropic(
compiled_references, compiled_references,
q, query_to_run,
online_results, online_results,
meta_log, meta_log,
model=conversation_config.chat_model, model=conversation_config.chat_model,
@@ -994,7 +1052,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
chat_response = converse_gemini( chat_response = converse_gemini(
compiled_references, compiled_references,
q, query_to_run,
online_results, online_results,
meta_log, meta_log,
model=conversation_config.chat_model, model=conversation_config.chat_model,

View File

@@ -0,0 +1,261 @@
import json
import logging
from typing import Any, Callable, Dict, List, Optional
from fastapi import Request
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import remove_json_codeblock
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ChatEvent,
construct_chat_history,
generate_summary_from_files,
send_message_to_model_wrapper,
)
from khoj.utils.helpers import (
ConversationCommand,
function_calling_description_for_llm,
timer,
)
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class InformationCollectionIteration:
    """One step of the research loop: the tool chosen, the query sent to it, and what came back."""

    def __init__(
        self, data_source: str, query: str, context: str = None, onlineContext: dict = None, result: Any = None
    ):
        # Which data source / tool was used for this iteration (e.g. "notes", "online").
        self.data_source = data_source
        # The (possibly LLM-rewritten) query issued to that tool.
        self.query = query
        # Compiled document references gathered this iteration, if any.
        self.context = context
        # Online search / webpage results gathered this iteration, if any.
        self.onlineContext = onlineContext
        # Bug fix: the `result` argument was previously accepted but silently dropped.
        self.result = result
async def apick_next_tool(
    query: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    previous_iterations: List[InformationCollectionIteration] = None,
):
    """
    Ask the LLM which data source to consult next (one per call) to answer `query`,
    given the chat history and the results of previous research iterations.

    Returns an InformationCollectionIteration with the chosen data_source and query;
    both are None when the model's response could not be parsed.
    """
    # Bug fix: the default is None but we iterate it below — normalize to a list.
    previous_iterations = previous_iterations or []

    # Build the list of tools offered to the model, restricted to the agent's
    # configured input tools when the agent declares any.
    tool_options = dict()
    tool_options_str = ""
    agent_tools = agent.input_tools if agent else []
    for tool, description in function_calling_description_for_llm.items():
        tool_options[tool.value] = description
        if len(agent_tools) == 0 or tool.value in agent_tools:
            tool_options_str += f'- "{tool.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    # Render each prior iteration so the model can refine rather than repeat itself.
    previous_iterations_history = ""
    for iteration in previous_iterations:
        iteration_data = prompts.previous_iteration.format(
            query=iteration.query,
            data_source=iteration.data_source,
            context=str(iteration.context),
            onlineContext=str(iteration.onlineContext),
        )
        previous_iterations_history += iteration_data

    if uploaded_image_url:
        query = f"[placeholder for user attached image]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    function_planning_prompt = prompts.plan_function_execution.format(
        query=query,
        tools=tool_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
        previous_iterations=previous_iterations_history,
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(
            function_planning_prompt,
            response_type="json_object",
            subscribed=subscribed,
        )

    try:
        response = response.strip()
        response = remove_json_codeblock(response)
        response = json.loads(response)
        suggested_data_source = response.get("data_source", None)
        suggested_query = response.get("query", None)

        return InformationCollectionIteration(
            data_source=suggested_data_source,
            query=suggested_query,
        )
    except Exception as e:
        logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
        # Fall back to a no-op iteration; the caller treats a None data_source as "stop".
        return InformationCollectionIteration(
            data_source=None,
            query=None,
        )
async def execute_information_collection(
    request: Request,
    user: KhojUser,
    query: str,
    conversation_id: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    location: LocationData = None,
    file_filters: Optional[List[str]] = None,
):
    """
    Run up to MAX_ITERATIONS research steps: ask the LLM which data source to use
    next (notes, online search, webpages, or summarization), execute it, and feed
    the results back into the next planning step.

    Async generator: yields status events while running, then the completed
    InformationCollectionIteration objects once the loop ends.
    """
    # Bug fix: avoid the shared mutable default argument.
    file_filters = file_filters or []

    iteration = 0
    MAX_ITERATIONS = 2
    previous_iterations: List[InformationCollectionIteration] = []
    while iteration < MAX_ITERATIONS:
        online_results: Dict = dict()
        compiled_references, inferred_queries, defiltered_query = [], [], None

        this_iteration = await apick_next_tool(
            query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
        )
        # NOTE(review): data_source is the raw string from the LLM's JSON; comparing
        # it to ConversationCommand assumes a str-valued enum — confirm.
        if this_iteration.data_source == ConversationCommand.Notes:
            ## Extract Document References
            compiled_references, inferred_queries, defiltered_query = [], [], None
            async for result in extract_references_and_questions(
                request,
                conversation_history,
                this_iteration.query,
                7,  # max results
                None,  # no distance threshold
                conversation_id,
                [ConversationCommand.Default],
                location,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    compiled_references.extend(result[0])
                    inferred_queries.extend(result[1])
                    defiltered_query = result[2]

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    context=str(compiled_references),
                )
            )

        elif this_iteration.data_source == ConversationCommand.Online:
            async for result in search_online(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                [],
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Webpage:
            async for result in read_webpages(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    # Merge the directly-read webpages into the online results.
                    direct_web_pages = result
                    webpages = []
                    # Bug fix: the loop variable previously shadowed the `query` parameter.
                    for webpage_query in direct_web_pages:
                        if online_results.get(webpage_query):
                            online_results[webpage_query]["webpages"] = direct_web_pages[webpage_query]["webpages"]
                        else:
                            online_results[webpage_query] = {"webpages": direct_web_pages[webpage_query]["webpages"]}

                        for webpage in direct_web_pages[webpage_query]["webpages"]:
                            webpages.append(webpage["link"])
                    # Bug fix: send_status_func returns an async generator; it must be
                    # iterated, not yielded as a bare object (and may be None).
                    if send_status_func:
                        async for status in send_status_func(f"**Read web pages**: {webpages}"):
                            yield status

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Summarize:
            response_log = ""
            agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
            if len(file_filters) == 0 and not agent_has_entries:
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context="No files selected for summarization.",
                    )
                )
            elif len(file_filters) > 1 and not agent_has_entries:
                response_log = "Only one file can be selected for summarization."
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
            else:

                async def _collect(message: str):
                    # Capture the summary text instead of streaming it to a client.
                    yield message

                # Bug fix: generate_summary_from_files is an async generator — it must
                # be iterated, not awaited (awaiting it raised TypeError).
                async for result in generate_summary_from_files(
                    q=query,
                    user=user,
                    file_filters=file_filters,
                    meta_log=conversation_history,
                    subscribed=subscribed,
                    send_status_func=send_status_func,
                    send_response_func=_collect,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        response_log = str(result)

                # Record the summary so later planning iterations can use it.
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
        else:
            # Unrecognized or empty data source: the planner chose to stop (or its
            # response could not be parsed), so end the research loop.
            iteration = MAX_ITERATIONS

        iteration += 1

    # Hand the completed iterations back to the caller (the chat endpoint).
    for completed_iter in previous_iterations:
        yield completed_iter

View File

@@ -345,6 +345,13 @@ tool_descriptions_for_llm = {
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.", ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
} }
# Tool descriptions shown to the LLM when planning research iterations (see
# prompts.plan_function_execution); keys are the tools the planner may pick.
function_calling_description_for_llm = {
    ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.",
    # Typo fixes below: "the there's" -> "there's"; "you are share of" -> "you are sure of".
    ConversationCommand.Online: "Use this if you think there's important information on the internet related to the query.",
    ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are sure of a webpage URL that will help you directly answer this query",
    ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.",
}
mode_descriptions_for_llm = { mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.", ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.", ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",