Merge branch 'master' into features/advanced-reasoning

This commit is contained in:
Debanjum Singh Solanky
2024-10-15 01:27:36 -07:00
9 changed files with 693 additions and 515 deletions

File diff suppressed because it is too large. [Load Diff]

View File

@@ -696,10 +696,12 @@ class AgentAdapters:
files: List[str], files: List[str],
input_tools: List[str], input_tools: List[str],
output_modes: List[str], output_modes: List[str],
slug: Optional[str] = None,
): ):
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst() chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create( # Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
defaults={ defaults={
"name": name, "name": name,
"creator": user, "creator": user,

View File

@@ -114,6 +114,7 @@ class CrossEncoderModel:
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(target_url, json=payload, headers=headers) response = requests.post(target_url, json=payload, headers=headers)
response.raise_for_status()
return response.json()["scores"] return response.json()["scores"]
cross_inp = [[query, hit.additional[key]] for hit in hits] cross_inp = [[query, hit.additional[key]] for hit in hits]

View File

@@ -143,7 +143,6 @@ async def read_webpages(
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None, uploaded_image_url: str = None,
agent: Agent = None, agent: Agent = None,

View File

@@ -35,6 +35,7 @@ class ModifyAgentBody(BaseModel):
files: Optional[List[str]] = [] files: Optional[List[str]] = []
input_tools: Optional[List[str]] = [] input_tools: Optional[List[str]] = []
output_modes: Optional[List[str]] = [] output_modes: Optional[List[str]] = []
slug: Optional[str] = None
@api_agents.get("", response_class=Response) @api_agents.get("", response_class=Response)
@@ -192,6 +193,7 @@ async def create_agent(
body.files, body.files,
body.input_tools, body.input_tools,
body.output_modes, body.output_modes,
body.slug,
) )
agents_packet = { agents_packet = {
@@ -233,7 +235,7 @@ async def update_agent(
status_code=400, status_code=400,
) )
selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user) selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
if not selected_agent: if not selected_agent:
return Response( return Response(
@@ -253,6 +255,7 @@ async def update_agent(
body.files, body.files,
body.input_tools, body.input_tools,
body.output_modes, body.output_modes,
body.slug,
) )
agents_packet = { agents_packet = {

View File

@@ -213,7 +213,7 @@ def chat_history(
agent_metadata = None agent_metadata = None
if conversation.agent: if conversation.agent:
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE: if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
conversation.agent = None conversation.agent = None
else: else:
agent_metadata = { agent_metadata = {
@@ -853,27 +853,36 @@ async def chat(
return return
# # Gather Context # # Gather Context
# async for result in extract_references_and_questions( # # Extract Document References
# request, # try:
# meta_log, # async for result in extract_references_and_questions(
# q, # request,
# (n or 7), # meta_log,
# d, # q,
# conversation_id, # (n or 7),
# conversation_commands, # d,
# location, # conversation_id,
# partial(send_event, ChatEvent.STATUS), # conversation_commands,
# uploaded_image_url=uploaded_image_url, # location,
# agent=agent, # partial(send_event, ChatEvent.STATUS),
# ): # uploaded_image_url=uploaded_image_url,
# if isinstance(result, dict) and ChatEvent.STATUS in result: # agent=agent,
# yield result[ChatEvent.STATUS] # ):
# else: # if isinstance(result, dict) and ChatEvent.STATUS in result:
# compiled_references.extend(result[0]) # yield result[ChatEvent.STATUS]
# inferred_queries.extend(result[1]) # else:
# defiltered_query = result[2] # compiled_references.extend(result[0])
# inferred_queries.extend(result[1])
# if not is_none_or_empty(compiled_references): # defiltered_query = result[2]
# except Exception as e:
# error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
# logger.warning(error_message)
# async for result in send_event(
# ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
# ):
# yield result
#
# # if not is_none_or_empty(compiled_references):
# try: # try:
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) # headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
# # Strip only leading # from headings # # Strip only leading # from headings
@@ -910,12 +919,13 @@ async def chat(
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
else: else:
online_results = result online_results = result
except ValueError as e: except Exception as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results" error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message) logger.warning(error_message)
async for result in send_llm_response(error_message): async for result in send_event(
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
):
yield result yield result
return
## Gather Webpage References ## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands and pending_research: if ConversationCommand.Webpage in conversation_commands and pending_research:
@@ -925,7 +935,6 @@ async def chat(
meta_log, meta_log,
location, location,
user, user,
subscribed,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent, agent=agent,
@@ -945,11 +954,15 @@ async def chat(
webpages.append(webpage["link"]) webpages.append(webpage["link"])
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
yield result yield result
except ValueError as e: except Exception as e:
logger.warning( logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results", f"Error reading webpages: {e}. Attempting to respond without webpage results",
exc_info=True, exc_info=True,
) )
async for result in send_event(
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
):
yield result
## Gather Code Results ## Gather Code Results
if ConversationCommand.Code in conversation_commands and pending_research: if ConversationCommand.Code in conversation_commands and pending_research:

View File

@@ -345,13 +345,13 @@ async def aget_relevant_information_sources(
final_response = [ConversationCommand.Default] final_response = [ConversationCommand.Default]
else: else:
final_response = [ConversationCommand.General] final_response = [ConversationCommand.General]
return final_response except Exception:
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}") logger.error(f"Invalid response for determining relevant tools: {response}")
if len(agent_tools) == 0: if len(agent_tools) == 0:
final_response = [ConversationCommand.Default] final_response = [ConversationCommand.Default]
else: else:
final_response = agent_tools final_response = agent_tools
return final_response
async def aget_relevant_output_modes( async def aget_relevant_output_modes(

View File

@@ -227,7 +227,6 @@ async def execute_information_collection(
conversation_history, conversation_history,
location, location,
user, user,
subscribed,
send_status_func, send_status_func,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent, agent=agent,

View File

@@ -3,6 +3,7 @@ import math
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Type, Union from typing import List, Optional, Tuple, Type, Union
import requests
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from sentence_transformers import util from sentence_transformers import util
@@ -231,8 +232,12 @@ def setup(
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder""" """Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device): try:
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
except requests.exceptions.HTTPError as e:
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
cross_scores = [0.0] * len(hits)
# Convert cross-encoder scores to distances and pass in hits for reranking # Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):