diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 71af5b7d..a19d85fa 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -114,6 +114,7 @@ class CrossEncoderModel: payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} response = requests.post(target_url, json=payload, headers=headers) + response.raise_for_status() return response.json()["scores"] cross_inp = [[query, hit.additional[key]] for hit in hits] diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 94a069da..03bf5f50 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,7 +3,6 @@ import base64 import json import logging import time -import warnings from datetime import datetime from functools import partial from typing import Dict, Optional @@ -840,25 +839,33 @@ async def chat( # Gather Context ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] + try: + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + d, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + except Exception as e: + error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references." + logger.warning(error_message) + async for result in send_event( + ChatEvent.STATUS, "Document search failed. I'll try respond without document references" + ): + yield result if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 52e23f29..b67132e4 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -3,6 +3,7 @@ import math from pathlib import Path from typing import List, Optional, Tuple, Type, Union +import requests import torch from asgiref.sync import sync_to_async from sentence_transformers import util @@ -231,8 +232,12 @@ def setup( def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" - with timer("Cross-Encoder Predict Time", logger, state.device): - cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + try: + with timer("Cross-Encoder Predict Time", logger, state.device): + cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + except requests.exceptions.HTTPError as e: + logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True) + cross_scores = [0.0] * len(hits) # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)):