Mirror of https://github.com/khoaliber/khoj.git, synced 2026-03-10 13:26:13 +00:00
Working prototype of meta-level chain of reasoning and execution
- Create a more dynamic reasoning agent that can evaluate the information it has, recognize what it doesn't know, and make moves to get that information (see the sketch below)
- Lots of hacks and code that need to be reverted later, before submission
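At a high level, the commit adds a research loop: the model is repeatedly asked which data source to consult next, context is gathered from that source, and the accumulated iterations are fed back into the next planning step until the model signals it has enough. Below is a minimal, self-contained sketch of that flow; the stub functions stand in for apick_next_tool() and the data source lookups defined in src/khoj/routers/research.py further down, so the names and return values here are illustrative only.

import asyncio
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Iteration:
    data_source: Optional[str]  # e.g. notes, online, webpage, summarize
    query: Optional[str]
    context: str = ""

async def pick_next_tool(query: str, previous: List[Iteration]) -> Iteration:
    # Stand-in for the LLM planning call, which returns
    # {"data_source": ..., "query": ...} or {} once it has enough context.
    return Iteration(None, None) if previous else Iteration("online", query)

async def gather(iteration: Iteration) -> str:
    # Stand-in for the notes / online / webpage / summarize lookups.
    return f"results for {iteration.query!r} from {iteration.data_source}"

async def research(query: str, max_iterations: int = 2) -> List[Iteration]:
    previous: List[Iteration] = []
    for _ in range(max_iterations):
        nxt = await pick_next_tool(query, previous)
        if not nxt.data_source:  # empty plan: the agent thinks it knows enough
            break
        nxt.context = await gather(nxt)
        previous.append(nxt)
    return previous

print(asyncio.run(research("What did the latest khoj commit change?")))

The real loop in execute_information_collection() below follows the same shape, capped at MAX_ITERATIONS = 2 and yielding each completed InformationCollectionIteration back to the chat endpoint.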
@@ -485,6 +485,47 @@ Khoj:
""".strip()
)

plan_function_execution = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time
- You are given multiple iterations to work with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution

If you already know the answer to the question, return an empty response, e.g., {{}}.

Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:

{tools}

Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.

Previous Iterations:
{previous_iterations}

Response format:
{{"data_source": "<tool_name>", "query": "<your_new_query>"}}

Chat History:
{chat_history}

Q: {query}
Khoj:
""".strip()
)

previous_iteration = PromptTemplate.from_template(
"""
data_source: {data_source}
query: {query}
context: {context}
onlineContext: {onlineContext}
---
""".strip()
)

pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.

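For reference, the planner prompt above asks the model for exactly one JSON object naming the next data source and query. A reply might look like the small example below; the concrete data_source and query values are illustrative, and apick_next_tool() in the new src/khoj/routers/research.py (further down in this commit) strips any code fences with remove_json_codeblock() before parsing the reply with json.loads().

import json

# Hypothetical planner reply for a weather question; values are illustrative.
raw = '{"data_source": "online", "query": "current weather in Berlin"}'
plan = json.loads(raw)
print(plan["data_source"], plan["query"])  # -> online current weather in Berlin
# An empty object, {}, signals that the agent already has enough context to answer.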
@@ -355,9 +355,10 @@ async def extract_references_and_questions(
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)

if (
not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands
and not agent_has_entries
# not ConversationCommand.Notes in conversation_commands
# and not ConversationCommand.Default in conversation_commands
# and not agent_has_entries
True
):
yield compiled_references, inferred_queries, q
return

@@ -41,7 +41,8 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
extract_relevant_summary,
extract_relevant_info,
generate_summary_from_files,
get_conversation_command,
is_query_empty,
is_ready_to_chat,
@@ -49,6 +50,10 @@ from khoj.routers.helpers import (
update_telemetry_state,
validate_conversation_config,
)
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.storage import upload_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
@@ -689,7 +694,46 @@ async def chat(
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]

pending_research = True

researched_results = ""
online_results: Dict = dict()
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None

if conversation_commands == [ConversationCommand.Default] or is_automated_task:
async for research_result in execute_information_collection(
request=request,
user=user,
query=q,
conversation_id=conversation_id,
conversation_history=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
location=location,
file_filters=conversation.file_filters if conversation else [],
):
if type(research_result) == InformationCollectionIteration:
pending_research = False
if research_result.onlineContext:
researched_results += str(research_result.onlineContext)
online_results.update(research_result.onlineContext)

if research_result.context:
researched_results += str(research_result.context)
compiled_references.extend(research_result.context)

else:
yield research_result

researched_results = await extract_relevant_info(q, researched_results, agent)

logger.info(f"Researched Results: {researched_results}")

pending_research = False

conversation_commands = await aget_relevant_information_sources(
q,
meta_log,
@@ -724,9 +768,11 @@ async def chat(
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
# not pending research
and not pending_research
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
elif ConversationCommand.Summarize in conversation_commands and pending_research:
response_log = ""
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
@@ -738,47 +784,15 @@ async def chat(
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
None, file_names[0], agent
)

if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result

response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
response_log = await generate_summary_from_files(
q=query,
user=user,
file_filters=file_filters,
meta_log=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
send_response_func=partial(send_llm_response),
)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -838,8 +852,6 @@ async def chat(
return

# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
@@ -867,8 +879,6 @@ async def chat(
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result

online_results: Dict = dict()

if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
@@ -878,7 +888,7 @@ async def chat(
conversation_commands.remove(ConversationCommand.Notes)

## Gather Online References
if ConversationCommand.Online in conversation_commands:
if ConversationCommand.Online in conversation_commands and pending_research:
try:
async for result in search_online(
defiltered_query,
@@ -903,7 +913,7 @@ async def chat(
return

## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
if ConversationCommand.Webpage in conversation_commands and pending_research:
try:
async for result in read_webpages(
defiltered_query,
@@ -1008,6 +1018,7 @@ async def chat(
defiltered_query,
meta_log,
conversation,
researched_results,
compiled_references,
online_results,
inferred_queries,
@@ -1051,36 +1062,32 @@ async def chat(
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)


# Deprecated API. Remove by end of September 2024
@api_chat.get("")
# @api_chat.post("")
@requires(["authenticated"])
async def get_chat(
async def old_chat(
request: Request,
common: CommonQueryParams,
q: str,
n: int = 7,
d: float = None,
stream: Optional[bool] = False,
title: Optional[str] = None,
conversation_id: Optional[str] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
image: Optional[str] = None,
body: ChatRequestBody,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
),
):
# Issue a deprecation warning
warnings.warn(
"The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.",
DeprecationWarning,
stacklevel=2,
)
# Access the parameters from the body
q = body.q
n = body.n
d = body.d
stream = body.stream
title = body.title
conversation_id = body.conversation_id
city = body.city
region = body.region
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image

async def event_generator(q: str, image: str):
start_time = time.perf_counter()
@@ -1108,7 +1115,7 @@ async def get_chat(
nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
logger.warning(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -1164,21 +1171,34 @@ async def get_chat(
conversation_commands = [get_conversation_command(query=q, any_references=True)]

conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
user,
client_application=request.user.client_app,
conversation_id=conversation_id,
title=title,
create_new=body.create_new,
)
if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
yield result
return
conversation_id = conversation.id
agent = conversation.agent if conversation.agent else None

agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
agent = conversation.agent

if not conversation.agent:
conversation.agent = default_agent
await conversation.asave()
agent = default_agent

await is_ready_to_chat(user)

user_name = await aget_user_name(user)
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)

if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
@@ -1192,7 +1212,12 @@ async def get_chat(

if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q,
meta_log,
is_automated_task,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@@ -1200,7 +1225,7 @@ async def get_chat(
):
yield result

mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@@ -1224,45 +1249,25 @@ async def get_chat(
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
yield result
elif len(file_filters) > 1:
elif len(file_filters) > 1 and not agent_has_entries:
response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result

response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
response_log = await generate_summary_from_files(
q=query,
user=user,
file_filters=file_filters,
meta_log=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
send_status_func=partial(send_event, ChatEvent.STATUS),
send_response_func=partial(send_llm_response),
)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -1335,6 +1340,7 @@ async def get_chat(
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1350,8 +1356,6 @@ async def get_chat(
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result

online_results: Dict = dict()

if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
@@ -1372,6 +1376,7 @@ async def get_chat(
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1395,6 +1400,7 @@ async def get_chat(
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1441,6 +1447,7 @@ async def get_chat(
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -14,6 +14,7 @@ from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
@@ -39,6 +40,7 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
create_khoj_token,
get_khoj_tokens,
get_user_name,
@@ -614,6 +616,58 @@ async def extract_relevant_summary(
return response.strip()


async def generate_summary_from_files(
q: str,
user: KhojUser,
file_filters: List[str],
meta_log: dict,
subscribed: bool,
uploaded_image_url: str = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
send_response_func: Optional[Callable] = None,
):
try:
file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent)

if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

if len(file_object) == 0:
response_log = (
"Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
)
async for result in send_response_func(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
yield result

response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
response_log = str(response)
async for result in send_response_func(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_response_func(response_log):
yield result


async def generate_better_image_prompt(
q: str,
conversation_history: str,
@@ -893,6 +947,7 @@ def generate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
meta_research: str = "",
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [],
@@ -910,6 +965,9 @@ def generate_chat_response(

metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@@ -937,7 +995,7 @@ def generate_chat_response(
chat_response = converse_offline(
references=compiled_references,
online_results=online_results,
user_query=q,
user_query=query_to_run,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
@@ -956,7 +1014,7 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
q,
query_to_run,
image_url=uploaded_image_url,
online_results=online_results,
conversation_log=meta_log,
@@ -977,7 +1035,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic(
compiled_references,
q,
query_to_run,
online_results,
meta_log,
model=conversation_config.chat_model,
@@ -994,7 +1052,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key
chat_response = converse_gemini(
compiled_references,
q,
query_to_run,
online_results,
meta_log,
model=conversation_config.chat_model,

src/khoj/routers/research.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import json
import logging
from typing import Any, Callable, Dict, List, Optional

from fastapi import Request

from khoj.database.adapters import EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import remove_json_codeblock
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ChatEvent,
construct_chat_history,
generate_summary_from_files,
send_message_to_model_wrapper,
)
from khoj.utils.helpers import (
ConversationCommand,
function_calling_description_for_llm,
timer,
)
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


class InformationCollectionIteration:
def __init__(
self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None
):
self.data_source = data_source
self.query = query
self.context = context
self.onlineContext = onlineContext


async def apick_next_tool(
query: str,
conversation_history: dict,
subscribed: bool,
uploaded_image_url: str = None,
agent: Agent = None,
previous_iterations: List[InformationCollectionIteration] = None,
):
"""
Given a query, determine which of the available tools the agent should use to answer it appropriately. Tools are picked one at a time, and subsequent iterations can be used to refine the answer.
"""

tool_options = dict()
tool_options_str = ""

agent_tools = agent.input_tools if agent else []

for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'

chat_history = construct_chat_history(conversation_history)

previous_iterations_history = ""
for iteration in previous_iterations:
iteration_data = prompts.previous_iteration.format(
query=iteration.query,
data_source=iteration.data_source,
context=str(iteration.context),
onlineContext=str(iteration.onlineContext),
)

previous_iterations_history += iteration_data

if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"

personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

function_planning_prompt = prompts.plan_function_execution.format(
query=query,
tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context,
previous_iterations=previous_iterations_history,
)

with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
function_planning_prompt,
response_type="json_object",
subscribed=subscribed,
)

try:
response = response.strip()
response = remove_json_codeblock(response)
response = json.loads(response)
suggested_data_source = response.get("data_source", None)
suggested_query = response.get("query", None)

return InformationCollectionIteration(
data_source=suggested_data_source,
query=suggested_query,
)

except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
return InformationCollectionIteration(
data_source=None,
query=None,
)

async def execute_information_collection(
request: Request,
user: KhojUser,
query: str,
conversation_id: str,
conversation_history: dict,
subscribed: bool,
uploaded_image_url: str = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
location: LocationData = None,
file_filters: List[str] = [],
):
iteration = 0
MAX_ITERATIONS = 2
previous_iterations = []
while iteration < MAX_ITERATIONS:
online_results: Dict = dict()
compiled_references, inferred_queries, defiltered_query = [], [], None
this_iteration = await apick_next_tool(
query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
)
if this_iteration.data_source == ConversationCommand.Notes:
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
conversation_history,
this_iteration.query,
7,
None,
conversation_id,
[ConversationCommand.Default],
location,
send_status_func,
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
previous_iterations.append(
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
context=str(compiled_references),
)
)

elif this_iteration.data_source == ConversationCommand.Online:
async for result in search_online(
this_iteration.query,
conversation_history,
location,
user,
subscribed,
send_status_func,
[],
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
online_results = result
previous_iterations.append(
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
onlineContext=online_results,
)
)

elif this_iteration.data_source == ConversationCommand.Webpage:
async for result in read_webpages(
this_iteration.query,
conversation_history,
location,
user,
subscribed,
send_status_func,
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages = result

webpages = []
for query in direct_web_pages:
if online_results.get(query):
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
else:
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}

for webpage in direct_web_pages[query]["webpages"]:
webpages.append(webpage["link"])
yield send_status_func(f"**Read web pages**: {webpages}")

previous_iterations.append(
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
onlineContext=online_results,
)
)

elif this_iteration.data_source == ConversationCommand.Summarize:
response_log = ""
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
previous_iterations.append(
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
context="No files selected for summarization.",
)
)
elif len(file_filters) > 1 and not agent_has_entries:
response_log = "Only one file can be selected for summarization."
previous_iterations.append(
InformationCollectionIteration(
data_source=this_iteration.data_source,
query=this_iteration.query,
context=response_log,
)
)
else:
response_log = await generate_summary_from_files(
q=query,
user=user,
file_filters=file_filters,
meta_log=conversation_history,
subscribed=subscribed,
send_status_func=send_status_func,
)
else:
iteration = MAX_ITERATIONS

iteration += 1
for completed_iter in previous_iterations:
yield completed_iter
@@ -345,6 +345,13 @@ tool_descriptions_for_llm = {
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}

function_calling_description_for_llm = {
ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.",
ConversationCommand.Online: "Use this if you think there's important information on the internet related to the query.",
ConversationCommand.Webpage: "Use this if the user has provided a webpage URL, or you are aware of a webpage URL, that will help you directly answer this query.",
ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.",
}

mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",