mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Working prototype of meta-level chain of reasoning and execution
- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information - Lots of hacks and code that needs to be reversed later on before submission
This commit is contained in:
@@ -485,6 +485,47 @@ Khoj:
|
|||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plan_function_execution = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
You are Khoj, an extremely smart and helpful search assistant.
|
||||||
|
{personality_context}
|
||||||
|
- You have access to a variety of data sources to help you answer the user's question
|
||||||
|
- You can use the data sources listed below to collect more relevant information, one at a time
|
||||||
|
- You are given multiple iterations to with these data sources to answer the user's question
|
||||||
|
- You are provided with additional context. If you have enough context to answer the question, then exit execution
|
||||||
|
|
||||||
|
If you already know the answer to the question, return an empty response, e.g., {{}}.
|
||||||
|
|
||||||
|
Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources:
|
||||||
|
|
||||||
|
{tools}
|
||||||
|
|
||||||
|
Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.
|
||||||
|
|
||||||
|
Previous Iterations:
|
||||||
|
{previous_iterations}
|
||||||
|
|
||||||
|
Response format:
|
||||||
|
{{"data_source": "<tool_name>", "query": "<your_new_query>"}}
|
||||||
|
|
||||||
|
Chat History:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
Q: {query}
|
||||||
|
Khoj:
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_iteration = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
data_source: {data_source}
|
||||||
|
query: {query}
|
||||||
|
context: {context}
|
||||||
|
onlineContext: {onlineContext}
|
||||||
|
---
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
You are Khoj, an extremely smart and helpful search assistant.
|
You are Khoj, an extremely smart and helpful search assistant.
|
||||||
|
|||||||
@@ -355,9 +355,10 @@ async def extract_references_and_questions(
|
|||||||
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not ConversationCommand.Notes in conversation_commands
|
# not ConversationCommand.Notes in conversation_commands
|
||||||
and not ConversationCommand.Default in conversation_commands
|
# and not ConversationCommand.Default in conversation_commands
|
||||||
and not agent_has_entries
|
# and not agent_has_entries
|
||||||
|
True
|
||||||
):
|
):
|
||||||
yield compiled_references, inferred_queries, q
|
yield compiled_references, inferred_queries, q
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ from khoj.routers.helpers import (
|
|||||||
aget_relevant_output_modes,
|
aget_relevant_output_modes,
|
||||||
construct_automation_created_message,
|
construct_automation_created_message,
|
||||||
create_automation,
|
create_automation,
|
||||||
extract_relevant_summary,
|
extract_relevant_info,
|
||||||
|
generate_summary_from_files,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
is_query_empty,
|
is_query_empty,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
@@ -49,6 +50,10 @@ from khoj.routers.helpers import (
|
|||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
)
|
)
|
||||||
|
from khoj.routers.research import (
|
||||||
|
InformationCollectionIteration,
|
||||||
|
execute_information_collection,
|
||||||
|
)
|
||||||
from khoj.routers.storage import upload_image_to_bucket
|
from khoj.routers.storage import upload_image_to_bucket
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
@@ -689,7 +694,46 @@ async def chat(
|
|||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||||
|
|
||||||
|
pending_research = True
|
||||||
|
|
||||||
|
researched_results = ""
|
||||||
|
online_results: Dict = dict()
|
||||||
|
## Extract Document References
|
||||||
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
|
async for research_result in execute_information_collection(
|
||||||
|
request=request,
|
||||||
|
user=user,
|
||||||
|
query=q,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
conversation_history=meta_log,
|
||||||
|
subscribed=subscribed,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
location=location,
|
||||||
|
file_filters=conversation.file_filters if conversation else [],
|
||||||
|
):
|
||||||
|
if type(research_result) == InformationCollectionIteration:
|
||||||
|
pending_research = False
|
||||||
|
if research_result.onlineContext:
|
||||||
|
researched_results += str(research_result.onlineContext)
|
||||||
|
online_results.update(research_result.onlineContext)
|
||||||
|
|
||||||
|
if research_result.context:
|
||||||
|
researched_results += str(research_result.context)
|
||||||
|
compiled_references.extend(research_result.context)
|
||||||
|
|
||||||
|
else:
|
||||||
|
yield research_result
|
||||||
|
|
||||||
|
researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||||
|
|
||||||
|
logger.info(f"Researched Results: {researched_results}")
|
||||||
|
|
||||||
|
pending_research = False
|
||||||
|
|
||||||
conversation_commands = await aget_relevant_information_sources(
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
q,
|
q,
|
||||||
meta_log,
|
meta_log,
|
||||||
@@ -724,9 +768,11 @@ async def chat(
|
|||||||
and not used_slash_summarize
|
and not used_slash_summarize
|
||||||
# but we can't actually summarize
|
# but we can't actually summarize
|
||||||
and len(file_filters) != 1
|
and len(file_filters) != 1
|
||||||
|
# not pending research
|
||||||
|
and not pending_research
|
||||||
):
|
):
|
||||||
conversation_commands.remove(ConversationCommand.Summarize)
|
conversation_commands.remove(ConversationCommand.Summarize)
|
||||||
elif ConversationCommand.Summarize in conversation_commands:
|
elif ConversationCommand.Summarize in conversation_commands and pending_research:
|
||||||
response_log = ""
|
response_log = ""
|
||||||
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
||||||
if len(file_filters) == 0 and not agent_has_entries:
|
if len(file_filters) == 0 and not agent_has_entries:
|
||||||
@@ -738,47 +784,15 @@ async def chat(
|
|||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
else:
|
else:
|
||||||
try:
|
response_log = await generate_summary_from_files(
|
||||||
file_object = None
|
q=query,
|
||||||
if await EntryAdapters.aagent_has_entries(agent):
|
user=user,
|
||||||
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
file_filters=file_filters,
|
||||||
if len(file_names) > 0:
|
meta_log=meta_log,
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
|
|
||||||
None, file_names[0], agent
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(file_filters) > 0:
|
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
|
||||||
|
|
||||||
if len(file_object) == 0:
|
|
||||||
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
|
|
||||||
async for result in send_llm_response(response_log):
|
|
||||||
yield result
|
|
||||||
return
|
|
||||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
|
||||||
if not q:
|
|
||||||
q = "Create a general summary of the file"
|
|
||||||
async for result in send_event(
|
|
||||||
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
|
|
||||||
):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
response = await extract_relevant_summary(
|
|
||||||
q,
|
|
||||||
contextual_data,
|
|
||||||
conversation_history=meta_log,
|
|
||||||
subscribed=subscribed,
|
subscribed=subscribed,
|
||||||
uploaded_image_url=uploaded_image_url,
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
agent=agent,
|
send_response_func=partial(send_llm_response),
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
|
||||||
async for result in send_llm_response(response_log):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
response_log = "Error summarizing file. Please try again, or contact support."
|
|
||||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
|
||||||
async for result in send_llm_response(response_log):
|
|
||||||
yield result
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
response_log,
|
response_log,
|
||||||
@@ -838,8 +852,6 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Gather Context
|
# Gather Context
|
||||||
## Extract Document References
|
|
||||||
compiled_references, inferred_queries, defiltered_query = [], [], None
|
|
||||||
async for result in extract_references_and_questions(
|
async for result in extract_references_and_questions(
|
||||||
request,
|
request,
|
||||||
meta_log,
|
meta_log,
|
||||||
@@ -867,8 +879,6 @@ async def chat(
|
|||||||
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
yield result
|
yield result
|
||||||
@@ -878,7 +888,7 @@ async def chat(
|
|||||||
conversation_commands.remove(ConversationCommand.Notes)
|
conversation_commands.remove(ConversationCommand.Notes)
|
||||||
|
|
||||||
## Gather Online References
|
## Gather Online References
|
||||||
if ConversationCommand.Online in conversation_commands:
|
if ConversationCommand.Online in conversation_commands and pending_research:
|
||||||
try:
|
try:
|
||||||
async for result in search_online(
|
async for result in search_online(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
@@ -903,7 +913,7 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
## Gather Webpage References
|
## Gather Webpage References
|
||||||
if ConversationCommand.Webpage in conversation_commands:
|
if ConversationCommand.Webpage in conversation_commands and pending_research:
|
||||||
try:
|
try:
|
||||||
async for result in read_webpages(
|
async for result in read_webpages(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
@@ -1008,6 +1018,7 @@ async def chat(
|
|||||||
defiltered_query,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
conversation,
|
conversation,
|
||||||
|
researched_results,
|
||||||
compiled_references,
|
compiled_references,
|
||||||
online_results,
|
online_results,
|
||||||
inferred_queries,
|
inferred_queries,
|
||||||
@@ -1051,36 +1062,32 @@ async def chat(
|
|||||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
# Deprecated API. Remove by end of September 2024
|
# @api_chat.post("")
|
||||||
@api_chat.get("")
|
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def get_chat(
|
async def old_chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
common: CommonQueryParams,
|
common: CommonQueryParams,
|
||||||
q: str,
|
body: ChatRequestBody,
|
||||||
n: int = 7,
|
|
||||||
d: float = None,
|
|
||||||
stream: Optional[bool] = False,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
city: Optional[str] = None,
|
|
||||||
region: Optional[str] = None,
|
|
||||||
country: Optional[str] = None,
|
|
||||||
timezone: Optional[str] = None,
|
|
||||||
image: Optional[str] = None,
|
|
||||||
rate_limiter_per_minute=Depends(
|
rate_limiter_per_minute=Depends(
|
||||||
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
|
ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
|
||||||
),
|
),
|
||||||
rate_limiter_per_day=Depends(
|
rate_limiter_per_day=Depends(
|
||||||
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Issue a deprecation warning
|
# Access the parameters from the body
|
||||||
warnings.warn(
|
q = body.q
|
||||||
"The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.",
|
n = body.n
|
||||||
DeprecationWarning,
|
d = body.d
|
||||||
stacklevel=2,
|
stream = body.stream
|
||||||
)
|
title = body.title
|
||||||
|
conversation_id = body.conversation_id
|
||||||
|
city = body.city
|
||||||
|
region = body.region
|
||||||
|
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||||
|
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||||
|
timezone = body.timezone
|
||||||
|
image = body.image
|
||||||
|
|
||||||
async def event_generator(q: str, image: str):
|
async def event_generator(q: str, image: str):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -1108,7 +1115,7 @@ async def get_chat(
|
|||||||
nonlocal connection_alive, ttft
|
nonlocal connection_alive, ttft
|
||||||
if not connection_alive or await request.is_disconnected():
|
if not connection_alive or await request.is_disconnected():
|
||||||
connection_alive = False
|
connection_alive = False
|
||||||
logger.warn(f"User {user} disconnected from {common.client} client")
|
logger.warning(f"User {user} disconnected from {common.client} client")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||||
@@ -1164,21 +1171,34 @@ async def get_chat(
|
|||||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||||
|
|
||||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
user, client_application=request.user.client_app, conversation_id=conversation_id, title=title
|
user,
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
title=title,
|
||||||
|
create_new=body.create_new,
|
||||||
)
|
)
|
||||||
if not conversation:
|
if not conversation:
|
||||||
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
|
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
conversation_id = conversation.id
|
conversation_id = conversation.id
|
||||||
agent = conversation.agent if conversation.agent else None
|
|
||||||
|
agent: Agent | None = None
|
||||||
|
default_agent = await AgentAdapters.aget_default_agent()
|
||||||
|
if conversation.agent and conversation.agent != default_agent:
|
||||||
|
agent = conversation.agent
|
||||||
|
|
||||||
|
if not conversation.agent:
|
||||||
|
conversation.agent = default_agent
|
||||||
|
await conversation.asave()
|
||||||
|
agent = default_agent
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
|
|
||||||
user_name = await aget_user_name(user)
|
user_name = await aget_user_name(user)
|
||||||
location = None
|
location = None
|
||||||
if city or region or country:
|
if city or region or country or country_code:
|
||||||
location = LocationData(city=city, region=region, country=country)
|
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
async for result in send_llm_response("Please ask your query to get started."):
|
async for result in send_llm_response("Please ask your query to get started."):
|
||||||
@@ -1192,7 +1212,12 @@ async def get_chat(
|
|||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
q,
|
||||||
|
meta_log,
|
||||||
|
is_automated_task,
|
||||||
|
subscribed=subscribed,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
)
|
)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
@@ -1200,7 +1225,7 @@ async def get_chat(
|
|||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
@@ -1224,45 +1249,25 @@ async def get_chat(
|
|||||||
conversation_commands.remove(ConversationCommand.Summarize)
|
conversation_commands.remove(ConversationCommand.Summarize)
|
||||||
elif ConversationCommand.Summarize in conversation_commands:
|
elif ConversationCommand.Summarize in conversation_commands:
|
||||||
response_log = ""
|
response_log = ""
|
||||||
if len(file_filters) == 0:
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
||||||
|
if len(file_filters) == 0 and not agent_has_entries:
|
||||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
elif len(file_filters) > 1:
|
elif len(file_filters) > 1 and not agent_has_entries:
|
||||||
response_log = "Only one file can be selected for summarization."
|
response_log = "Only one file can be selected for summarization."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
else:
|
else:
|
||||||
try:
|
response_log = await generate_summary_from_files(
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
q=query,
|
||||||
if len(file_object) == 0:
|
user=user,
|
||||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
file_filters=file_filters,
|
||||||
async for result in send_llm_response(response_log):
|
meta_log=meta_log,
|
||||||
yield result
|
|
||||||
return
|
|
||||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
|
||||||
if not q:
|
|
||||||
q = "Create a general summary of the file"
|
|
||||||
async for result in send_event(
|
|
||||||
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
|
|
||||||
):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
response = await extract_relevant_summary(
|
|
||||||
q,
|
|
||||||
contextual_data,
|
|
||||||
conversation_history=meta_log,
|
|
||||||
subscribed=subscribed,
|
subscribed=subscribed,
|
||||||
uploaded_image_url=uploaded_image_url,
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
send_response_func=partial(send_llm_response),
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
|
||||||
async for result in send_llm_response(response_log):
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
|
||||||
response_log = "Error summarizing file."
|
|
||||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
|
||||||
async for result in send_llm_response(response_log):
|
|
||||||
yield result
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
response_log,
|
response_log,
|
||||||
@@ -1335,6 +1340,7 @@ async def get_chat(
|
|||||||
location,
|
location,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
@@ -1350,8 +1356,6 @@ async def get_chat(
|
|||||||
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
online_results: Dict = dict()
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
yield result
|
yield result
|
||||||
@@ -1372,6 +1376,7 @@ async def get_chat(
|
|||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
custom_filters,
|
custom_filters,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
@@ -1395,6 +1400,7 @@ async def get_chat(
|
|||||||
subscribed,
|
subscribed,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
@@ -1441,6 +1447,7 @@ async def get_chat(
|
|||||||
subscribed=subscribed,
|
subscribed=subscribed,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import (
|
|||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
@@ -39,6 +40,7 @@ from khoj.database.adapters import (
|
|||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
FileObjectAdapters,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
get_user_name,
|
get_user_name,
|
||||||
@@ -614,6 +616,58 @@ async def extract_relevant_summary(
|
|||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_summary_from_files(
|
||||||
|
q: str,
|
||||||
|
user: KhojUser,
|
||||||
|
file_filters: List[str],
|
||||||
|
meta_log: dict,
|
||||||
|
subscribed: bool,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
send_response_func: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
file_object = None
|
||||||
|
if await EntryAdapters.aagent_has_entries(agent):
|
||||||
|
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
||||||
|
if len(file_names) > 0:
|
||||||
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent)
|
||||||
|
|
||||||
|
if len(file_filters) > 0:
|
||||||
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||||
|
|
||||||
|
if len(file_object) == 0:
|
||||||
|
response_log = (
|
||||||
|
"Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
|
||||||
|
)
|
||||||
|
async for result in send_response_func(response_log):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||||
|
if not q:
|
||||||
|
q = "Create a general summary of the file"
|
||||||
|
async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
response = await extract_relevant_summary(
|
||||||
|
q,
|
||||||
|
contextual_data,
|
||||||
|
conversation_history=meta_log,
|
||||||
|
subscribed=subscribed,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
response_log = str(response)
|
||||||
|
async for result in send_response_func(response_log):
|
||||||
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
response_log = "Error summarizing file. Please try again, or contact support."
|
||||||
|
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||||
|
async for result in send_response_func(response_log):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
async def generate_better_image_prompt(
|
async def generate_better_image_prompt(
|
||||||
q: str,
|
q: str,
|
||||||
conversation_history: str,
|
conversation_history: str,
|
||||||
@@ -893,6 +947,7 @@ def generate_chat_response(
|
|||||||
q: str,
|
q: str,
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
|
meta_research: str = "",
|
||||||
compiled_references: List[Dict] = [],
|
compiled_references: List[Dict] = [],
|
||||||
online_results: Dict[str, Dict] = {},
|
online_results: Dict[str, Dict] = {},
|
||||||
inferred_queries: List[str] = [],
|
inferred_queries: List[str] = [],
|
||||||
@@ -910,6 +965,9 @@ def generate_chat_response(
|
|||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||||
|
query_to_run = q
|
||||||
|
if meta_research:
|
||||||
|
query_to_run = f"AI Research: {meta_research} {q}"
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
@@ -937,7 +995,7 @@ def generate_chat_response(
|
|||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
user_query=q,
|
user_query=query_to_run,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
@@ -956,7 +1014,7 @@ def generate_chat_response(
|
|||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
q,
|
query_to_run,
|
||||||
image_url=uploaded_image_url,
|
image_url=uploaded_image_url,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
@@ -977,7 +1035,7 @@ def generate_chat_response(
|
|||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
chat_response = converse_anthropic(
|
chat_response = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
q,
|
query_to_run,
|
||||||
online_results,
|
online_results,
|
||||||
meta_log,
|
meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=conversation_config.chat_model,
|
||||||
@@ -994,7 +1052,7 @@ def generate_chat_response(
|
|||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
chat_response = converse_gemini(
|
chat_response = converse_gemini(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
q,
|
query_to_run,
|
||||||
online_results,
|
online_results,
|
||||||
meta_log,
|
meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=conversation_config.chat_model,
|
||||||
|
|||||||
261
src/khoj/routers/research.py
Normal file
261
src/khoj/routers/research.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from khoj.database.adapters import EntryAdapters
|
||||||
|
from khoj.database.models import Agent, KhojUser
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.processor.conversation.utils import remove_json_codeblock
|
||||||
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
|
from khoj.routers.api import extract_references_and_questions
|
||||||
|
from khoj.routers.helpers import (
|
||||||
|
ChatEvent,
|
||||||
|
construct_chat_history,
|
||||||
|
generate_summary_from_files,
|
||||||
|
send_message_to_model_wrapper,
|
||||||
|
)
|
||||||
|
from khoj.utils.helpers import (
|
||||||
|
ConversationCommand,
|
||||||
|
function_calling_description_for_llm,
|
||||||
|
timer,
|
||||||
|
)
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InformationCollectionIteration:
|
||||||
|
def __init__(
|
||||||
|
self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None
|
||||||
|
):
|
||||||
|
self.data_source = data_source
|
||||||
|
self.query = query
|
||||||
|
self.context = context
|
||||||
|
self.onlineContext = onlineContext
|
||||||
|
|
||||||
|
|
||||||
|
async def apick_next_tool(
|
||||||
|
query: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
subscribed: bool,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
previous_iterations: List[InformationCollectionIteration] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_options = dict()
|
||||||
|
tool_options_str = ""
|
||||||
|
|
||||||
|
agent_tools = agent.input_tools if agent else []
|
||||||
|
|
||||||
|
for tool, description in function_calling_description_for_llm.items():
|
||||||
|
tool_options[tool.value] = description
|
||||||
|
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||||
|
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||||
|
|
||||||
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
|
previous_iterations_history = ""
|
||||||
|
for iteration in previous_iterations:
|
||||||
|
iteration_data = prompts.previous_iteration.format(
|
||||||
|
query=iteration.query,
|
||||||
|
data_source=iteration.data_source,
|
||||||
|
context=str(iteration.context),
|
||||||
|
onlineContext=str(iteration.onlineContext),
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_iterations_history += iteration_data
|
||||||
|
|
||||||
|
if uploaded_image_url:
|
||||||
|
query = f"[placeholder for user attached image]\n{query}"
|
||||||
|
|
||||||
|
personality_context = (
|
||||||
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
function_planning_prompt = prompts.plan_function_execution.format(
|
||||||
|
query=query,
|
||||||
|
tools=tool_options_str,
|
||||||
|
chat_history=chat_history,
|
||||||
|
personality_context=personality_context,
|
||||||
|
previous_iterations=previous_iterations_history,
|
||||||
|
)
|
||||||
|
|
||||||
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
|
response = await send_message_to_model_wrapper(
|
||||||
|
function_planning_prompt,
|
||||||
|
response_type="json_object",
|
||||||
|
subscribed=subscribed,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = response.strip()
|
||||||
|
response = remove_json_codeblock(response)
|
||||||
|
response = json.loads(response)
|
||||||
|
suggested_data_source = response.get("data_source", None)
|
||||||
|
suggested_query = response.get("query", None)
|
||||||
|
|
||||||
|
return InformationCollectionIteration(
|
||||||
|
data_source=suggested_data_source,
|
||||||
|
query=suggested_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
|
||||||
|
return InformationCollectionIteration(
|
||||||
|
data_source=None,
|
||||||
|
query=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_information_collection(
|
||||||
|
request: Request,
|
||||||
|
user: KhojUser,
|
||||||
|
query: str,
|
||||||
|
conversation_id: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
subscribed: bool,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
location: LocationData = None,
|
||||||
|
file_filters: List[str] = [],
|
||||||
|
):
|
||||||
|
iteration = 0
|
||||||
|
MAX_ITERATIONS = 2
|
||||||
|
previous_iterations = []
|
||||||
|
while iteration < MAX_ITERATIONS:
|
||||||
|
online_results: Dict = dict()
|
||||||
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
|
this_iteration = await apick_next_tool(
|
||||||
|
query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
|
||||||
|
)
|
||||||
|
if this_iteration.data_source == ConversationCommand.Notes:
|
||||||
|
## Extract Document References
|
||||||
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
|
async for result in extract_references_and_questions(
|
||||||
|
request,
|
||||||
|
conversation_history,
|
||||||
|
this_iteration.query,
|
||||||
|
7,
|
||||||
|
None,
|
||||||
|
conversation_id,
|
||||||
|
[ConversationCommand.Default],
|
||||||
|
location,
|
||||||
|
send_status_func,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
compiled_references.extend(result[0])
|
||||||
|
inferred_queries.extend(result[1])
|
||||||
|
defiltered_query = result[2]
|
||||||
|
previous_iterations.append(
|
||||||
|
InformationCollectionIteration(
|
||||||
|
data_source=this_iteration.data_source,
|
||||||
|
query=this_iteration.query,
|
||||||
|
context=str(compiled_references),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif this_iteration.data_source == ConversationCommand.Online:
|
||||||
|
async for result in search_online(
|
||||||
|
this_iteration.query,
|
||||||
|
conversation_history,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
subscribed,
|
||||||
|
send_status_func,
|
||||||
|
[],
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
online_results = result
|
||||||
|
previous_iterations.append(
|
||||||
|
InformationCollectionIteration(
|
||||||
|
data_source=this_iteration.data_source,
|
||||||
|
query=this_iteration.query,
|
||||||
|
onlineContext=online_results,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif this_iteration.data_source == ConversationCommand.Webpage:
|
||||||
|
async for result in read_webpages(
|
||||||
|
this_iteration.query,
|
||||||
|
conversation_history,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
subscribed,
|
||||||
|
send_status_func,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
agent=agent,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
direct_web_pages = result
|
||||||
|
|
||||||
|
webpages = []
|
||||||
|
for query in direct_web_pages:
|
||||||
|
if online_results.get(query):
|
||||||
|
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
||||||
|
else:
|
||||||
|
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
||||||
|
|
||||||
|
for webpage in direct_web_pages[query]["webpages"]:
|
||||||
|
webpages.append(webpage["link"])
|
||||||
|
yield send_status_func(f"**Read web pages**: {webpages}")
|
||||||
|
|
||||||
|
previous_iterations.append(
|
||||||
|
InformationCollectionIteration(
|
||||||
|
data_source=this_iteration.data_source,
|
||||||
|
query=this_iteration.query,
|
||||||
|
onlineContext=online_results,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif this_iteration.data_source == ConversationCommand.Summarize:
|
||||||
|
response_log = ""
|
||||||
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
||||||
|
if len(file_filters) == 0 and not agent_has_entries:
|
||||||
|
previous_iterations.append(
|
||||||
|
InformationCollectionIteration(
|
||||||
|
data_source=this_iteration.data_source,
|
||||||
|
query=this_iteration.query,
|
||||||
|
context="No files selected for summarization.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif len(file_filters) > 1 and not agent_has_entries:
|
||||||
|
response_log = "Only one file can be selected for summarization."
|
||||||
|
previous_iterations.append(
|
||||||
|
InformationCollectionIteration(
|
||||||
|
data_source=this_iteration.data_source,
|
||||||
|
query=this_iteration.query,
|
||||||
|
context=response_log,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response_log = await generate_summary_from_files(
|
||||||
|
q=query,
|
||||||
|
user=user,
|
||||||
|
file_filters=file_filters,
|
||||||
|
meta_log=conversation_history,
|
||||||
|
subscribed=subscribed,
|
||||||
|
send_status_func=send_status_func,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
iteration = MAX_ITERATIONS
|
||||||
|
|
||||||
|
iteration += 1
|
||||||
|
for completed_iter in previous_iterations:
|
||||||
|
yield completed_iter
|
||||||
@@ -345,6 +345,13 @@ tool_descriptions_for_llm = {
|
|||||||
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
|
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function_calling_description_for_llm = {
|
||||||
|
ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.",
|
||||||
|
ConversationCommand.Online: "Use this if you think the there's important information on the internet related to the query.",
|
||||||
|
ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are share of a webpage URL that will help you directly answer this query",
|
||||||
|
ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.",
|
||||||
|
}
|
||||||
|
|
||||||
mode_descriptions_for_llm = {
|
mode_descriptions_for_llm = {
|
||||||
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
|
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
|
||||||
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
|
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
|
||||||
|
|||||||
Reference in New Issue
Block a user