mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Try to respond even if document search via inference endpoint fails
The huggingface endpoint can be flaky. Khoj shouldn't refuse to respond to the user if document search fails. It should transparently mention that document lookup failed, but try to respond as best it can without the document references. This change provides graceful failover when inference endpoint requests fail, either when encoding the query or when reranking the retrieved docs.
This commit is contained in:
@@ -114,6 +114,7 @@ class CrossEncoderModel:
|
|||||||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
response = requests.post(target_url, json=payload, headers=headers)
|
response = requests.post(target_url, json=payload, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
return response.json()["scores"]
|
return response.json()["scores"]
|
||||||
|
|
||||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@@ -840,25 +839,33 @@ async def chat(
|
|||||||
# Gather Context
|
# Gather Context
|
||||||
## Extract Document References
|
## Extract Document References
|
||||||
compiled_references, inferred_queries, defiltered_query = [], [], None
|
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||||
async for result in extract_references_and_questions(
|
try:
|
||||||
request,
|
async for result in extract_references_and_questions(
|
||||||
meta_log,
|
request,
|
||||||
q,
|
meta_log,
|
||||||
(n or 7),
|
q,
|
||||||
d,
|
(n or 7),
|
||||||
conversation_id,
|
d,
|
||||||
conversation_commands,
|
conversation_id,
|
||||||
location,
|
conversation_commands,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
location,
|
||||||
uploaded_image_url=uploaded_image_url,
|
partial(send_event, ChatEvent.STATUS),
|
||||||
agent=agent,
|
uploaded_image_url=uploaded_image_url,
|
||||||
):
|
agent=agent,
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
):
|
||||||
yield result[ChatEvent.STATUS]
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
else:
|
yield result[ChatEvent.STATUS]
|
||||||
compiled_references.extend(result[0])
|
else:
|
||||||
inferred_queries.extend(result[1])
|
compiled_references.extend(result[0])
|
||||||
defiltered_query = result[2]
|
inferred_queries.extend(result[1])
|
||||||
|
defiltered_query = result[2]
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||||
|
logger.warning(error_message)
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||||
|
):
|
||||||
|
yield result
|
||||||
|
|
||||||
if not is_none_or_empty(compiled_references):
|
if not is_none_or_empty(compiled_references):
|
||||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Type, Union
|
from typing import List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
@@ -231,8 +232,12 @@ def setup(
|
|||||||
|
|
||||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
try:
|
||||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
|
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
|
||||||
|
cross_scores = [0.0] * len(hits)
|
||||||
|
|
||||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
|
|||||||
Reference in New Issue
Block a user