From 1b04b801c6e6e62a09a16d4a5eddff7dbafe9590 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 14 Oct 2024 17:39:44 -0700 Subject: [PATCH] Try respond even if document search via inference endpoint fails The huggingface endpoint can be flaky. Khoj shouldn't refuse to respond to user if document search fails. It should transparently mention that document lookup failed. But try respond as best as it can without the document references This changes provides graceful failover when inference endpoint requests fail either when encoding query or reranking retrieved docs --- src/khoj/processor/embeddings.py | 1 + src/khoj/routers/api_chat.py | 47 +++++++++++++++++------------ src/khoj/search_type/text_search.py | 9 ++++-- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 71af5b7d..a19d85fa 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -114,6 +114,7 @@ class CrossEncoderModel: payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} response = requests.post(target_url, json=payload, headers=headers) + response.raise_for_status() return response.json()["scores"] cross_inp = [[query, hit.additional[key]] for hit in hits] diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 9022a7dc..93c905b6 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,7 +3,6 @@ import base64 import json import logging import time -import warnings from datetime import datetime from functools import partial from typing import Dict, Optional @@ -839,25 +838,33 @@ async def chat( # Gather Context ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] + try: + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + d, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + except Exception as e: + error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references." + logger.warning(error_message) + async for result in send_event( + ChatEvent.STATUS, "Document search failed. I'll try respond without document references" + ): + yield result if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 52e23f29..b67132e4 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -3,6 +3,7 @@ import math from pathlib import Path from typing import List, Optional, Tuple, Type, Union +import requests import torch from asgiref.sync import sync_to_async from sentence_transformers import util @@ -231,8 +232,12 @@ def setup( def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" - with timer("Cross-Encoder Predict Time", logger, state.device): - cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + try: + with timer("Cross-Encoder Predict Time", logger, state.device): + cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + except requests.exceptions.HTTPError as e: + logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True) + cross_scores = [0.0] * len(hits) # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)):