mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Merge branch 'master' into features/advanced-reasoning
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -696,10 +696,12 @@ class AgentAdapters:
|
|||||||
files: List[str],
|
files: List[str],
|
||||||
input_tools: List[str],
|
input_tools: List[str],
|
||||||
output_modes: List[str],
|
output_modes: List[str],
|
||||||
|
slug: Optional[str] = None,
|
||||||
):
|
):
|
||||||
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
|
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
|
||||||
|
|
||||||
agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
|
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
||||||
|
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
||||||
defaults={
|
defaults={
|
||||||
"name": name,
|
"name": name,
|
||||||
"creator": user,
|
"creator": user,
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ class CrossEncoderModel:
|
|||||||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
response = requests.post(target_url, json=payload, headers=headers)
|
response = requests.post(target_url, json=payload, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
return response.json()["scores"]
|
return response.json()["scores"]
|
||||||
|
|
||||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
|
|||||||
@@ -143,7 +143,6 @@ async def read_webpages(
|
|||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location: LocationData,
|
location: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
subscribed: bool = False,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class ModifyAgentBody(BaseModel):
|
|||||||
files: Optional[List[str]] = []
|
files: Optional[List[str]] = []
|
||||||
input_tools: Optional[List[str]] = []
|
input_tools: Optional[List[str]] = []
|
||||||
output_modes: Optional[List[str]] = []
|
output_modes: Optional[List[str]] = []
|
||||||
|
slug: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@api_agents.get("", response_class=Response)
|
@api_agents.get("", response_class=Response)
|
||||||
@@ -192,6 +193,7 @@ async def create_agent(
|
|||||||
body.files,
|
body.files,
|
||||||
body.input_tools,
|
body.input_tools,
|
||||||
body.output_modes,
|
body.output_modes,
|
||||||
|
body.slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
agents_packet = {
|
agents_packet = {
|
||||||
@@ -233,7 +235,7 @@ async def update_agent(
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)
|
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
|
||||||
|
|
||||||
if not selected_agent:
|
if not selected_agent:
|
||||||
return Response(
|
return Response(
|
||||||
@@ -253,6 +255,7 @@ async def update_agent(
|
|||||||
body.files,
|
body.files,
|
||||||
body.input_tools,
|
body.input_tools,
|
||||||
body.output_modes,
|
body.output_modes,
|
||||||
|
body.slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
agents_packet = {
|
agents_packet = {
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ def chat_history(
|
|||||||
|
|
||||||
agent_metadata = None
|
agent_metadata = None
|
||||||
if conversation.agent:
|
if conversation.agent:
|
||||||
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE:
|
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
|
||||||
conversation.agent = None
|
conversation.agent = None
|
||||||
else:
|
else:
|
||||||
agent_metadata = {
|
agent_metadata = {
|
||||||
@@ -853,27 +853,36 @@ async def chat(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# # Gather Context
|
# # Gather Context
|
||||||
# async for result in extract_references_and_questions(
|
# # Extract Document References
|
||||||
# request,
|
# try:
|
||||||
# meta_log,
|
# async for result in extract_references_and_questions(
|
||||||
# q,
|
# request,
|
||||||
# (n or 7),
|
# meta_log,
|
||||||
# d,
|
# q,
|
||||||
# conversation_id,
|
# (n or 7),
|
||||||
# conversation_commands,
|
# d,
|
||||||
# location,
|
# conversation_id,
|
||||||
# partial(send_event, ChatEvent.STATUS),
|
# conversation_commands,
|
||||||
# uploaded_image_url=uploaded_image_url,
|
# location,
|
||||||
# agent=agent,
|
# partial(send_event, ChatEvent.STATUS),
|
||||||
# ):
|
# uploaded_image_url=uploaded_image_url,
|
||||||
# if isinstance(result, dict) and ChatEvent.STATUS in result:
|
# agent=agent,
|
||||||
# yield result[ChatEvent.STATUS]
|
# ):
|
||||||
# else:
|
# if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
# compiled_references.extend(result[0])
|
# yield result[ChatEvent.STATUS]
|
||||||
# inferred_queries.extend(result[1])
|
# else:
|
||||||
# defiltered_query = result[2]
|
# compiled_references.extend(result[0])
|
||||||
|
# inferred_queries.extend(result[1])
|
||||||
# if not is_none_or_empty(compiled_references):
|
# defiltered_query = result[2]
|
||||||
|
# except Exception as e:
|
||||||
|
# error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||||
|
# logger.warning(error_message)
|
||||||
|
# async for result in send_event(
|
||||||
|
# ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||||
|
# ):
|
||||||
|
# yield result
|
||||||
|
#
|
||||||
|
# # if not is_none_or_empty(compiled_references):
|
||||||
# try:
|
# try:
|
||||||
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
# # Strip only leading # from headings
|
# # Strip only leading # from headings
|
||||||
@@ -910,12 +919,13 @@ async def chat(
|
|||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
online_results = result
|
online_results = result
|
||||||
except ValueError as e:
|
except Exception as e:
|
||||||
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||||
logger.warning(error_message)
|
logger.warning(error_message)
|
||||||
async for result in send_llm_response(error_message):
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
||||||
|
):
|
||||||
yield result
|
yield result
|
||||||
return
|
|
||||||
|
|
||||||
## Gather Webpage References
|
## Gather Webpage References
|
||||||
if ConversationCommand.Webpage in conversation_commands and pending_research:
|
if ConversationCommand.Webpage in conversation_commands and pending_research:
|
||||||
@@ -925,7 +935,6 @@ async def chat(
|
|||||||
meta_log,
|
meta_log,
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
subscribed,
|
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
@@ -945,11 +954,15 @@ async def chat(
|
|||||||
webpages.append(webpage["link"])
|
webpages.append(webpage["link"])
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
||||||
yield result
|
yield result
|
||||||
except ValueError as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error directly reading webpages: {e}. Attempting to respond without online results",
|
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
## Gather Code Results
|
## Gather Code Results
|
||||||
if ConversationCommand.Code in conversation_commands and pending_research:
|
if ConversationCommand.Code in conversation_commands and pending_research:
|
||||||
|
|||||||
@@ -345,13 +345,13 @@ async def aget_relevant_information_sources(
|
|||||||
final_response = [ConversationCommand.Default]
|
final_response = [ConversationCommand.Default]
|
||||||
else:
|
else:
|
||||||
final_response = [ConversationCommand.General]
|
final_response = [ConversationCommand.General]
|
||||||
return final_response
|
except Exception:
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||||
if len(agent_tools) == 0:
|
if len(agent_tools) == 0:
|
||||||
final_response = [ConversationCommand.Default]
|
final_response = [ConversationCommand.Default]
|
||||||
else:
|
else:
|
||||||
final_response = agent_tools
|
final_response = agent_tools
|
||||||
|
return final_response
|
||||||
|
|
||||||
|
|
||||||
async def aget_relevant_output_modes(
|
async def aget_relevant_output_modes(
|
||||||
|
|||||||
@@ -227,7 +227,6 @@ async def execute_information_collection(
|
|||||||
conversation_history,
|
conversation_history,
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
subscribed,
|
|
||||||
send_status_func,
|
send_status_func,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Type, Union
|
from typing import List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
@@ -231,8 +232,12 @@ def setup(
|
|||||||
|
|
||||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
try:
|
||||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
|
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
|
||||||
|
cross_scores = [0.0] * len(hits)
|
||||||
|
|
||||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
|
|||||||
Reference in New Issue
Block a user