mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Reuse Search Models across Content Types to reduce Memory Consumption - Memory consumption now only scales with search models used, not with content types. Previously each content type had its own copy of the search ML models. That'd result in 300+ MB per enabled text content type - Split model state into 2 separate state objects, `search_models` and `content_index`. This allows loading text_search and image_search models first and then reusing them across all content_types in content_index - The change should cut down memory utilization quite a bit for most users. I see a >50% drop in memory utilization on my Khoj instance. But this will vary for each user based on the amount of content indexed vs the number of plugins enabled. - This change does not solve the RAM utilization scaling with the size of the index, as the whole content index is still kept in RAM while Khoj is running. Should help with #195, #301 and #303
677 lines
23 KiB
Python
677 lines
23 KiB
Python
# Standard Packages
|
|
import concurrent.futures
|
|
import math
|
|
import time
|
|
import yaml
|
|
import logging
|
|
import json
|
|
from typing import List, Optional, Union
|
|
|
|
# External Packages
|
|
from fastapi import APIRouter, HTTPException, Header, Request
|
|
from sentence_transformers import util
|
|
|
|
# Internal Packages
|
|
from khoj.configure import configure_content, configure_processor, configure_search
|
|
from khoj.search_type import image_search, text_search
|
|
from khoj.search_filter.date_filter import DateFilter
|
|
from khoj.search_filter.file_filter import FileFilter
|
|
from khoj.search_filter.word_filter import WordFilter
|
|
from khoj.utils.config import TextSearchModel
|
|
from khoj.utils.helpers import log_telemetry, timer
|
|
from khoj.utils.rawconfig import (
|
|
ContentConfig,
|
|
FullConfig,
|
|
ProcessorConfig,
|
|
SearchConfig,
|
|
SearchResponse,
|
|
TextContentConfig,
|
|
ConversationProcessorConfig,
|
|
GithubContentConfig,
|
|
NotionContentConfig,
|
|
)
|
|
from khoj.utils.state import SearchType
|
|
from khoj.utils import state, constants
|
|
from khoj.utils.yaml import save_config_to_file_updated_state
|
|
from fastapi.responses import StreamingResponse, Response
|
|
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
|
|
from khoj.processor.conversation.gpt import extract_questions
|
|
from fastapi.requests import Request
|
|
|
|
|
|
# Initialize Router
api = APIRouter()
# Module-level logger named after this module, per the logging cookbook convention
logger = logging.getLogger(__name__)
|
# If it's a demo instance, prevent updating any of the configuration.
# NOTE(review): indentation was lost in this copy of the file; in the upstream
# source the config-mutation endpoints below are also nested under this guard.
# Confirm the intended guard scope before relying on it.
if not state.demo:

    def _initialize_config():
        # Ensure a FullConfig with default search-type settings exists before
        # any config-mutation endpoint touches state.config.
        if state.config is None:
            state.config = FullConfig()
            state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
    """Return the server's current configuration."""
    current_config = state.config
    return current_config
|
@api.post("/config/data")
async def set_config_data(
    request: Request,
    updated_config: FullConfig,
    client: Optional[str] = None,
):
    """Replace the full server configuration and persist it to the config file.

    Returns the updated configuration. Telemetry records only booleans for
    which config sections are populated, never the config contents.
    """
    state.config = updated_config
    # The context manager closes the file; the previous explicit close() was redundant.
    with open(state.config_file, "w") as outfile:
        yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)

    configuration_update_metadata = dict()

    if state.config.content_type is not None:
        # One presence flag per supported content type
        for content_type in ("github", "notion", "org", "pdf", "markdown", "plugins"):
            configuration_update_metadata[content_type] = (
                getattr(state.config.content_type, content_type) is not None
            )

    if state.config.processor is not None:
        configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_config",
        client=client,
        metadata=configuration_update_metadata,
    )
    return state.config
|
@api.post("/config/data/content_type/github", status_code=200)
async def set_content_config_github_data(
    request: Request,
    updated_config: Union[GithubContentConfig, None],
    client: Optional[str] = None,
):
    """Set the GitHub content configuration and persist it to the config file."""
    _initialize_config()

    # Attach the github section, creating the content-type container if missing
    if state.config.content_type:
        state.config.content_type.github = updated_config
    else:
        state.config.content_type = ContentConfig(**{"github": updated_config})

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "github"},
    )

    try:
        save_config_to_file_updated_state()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "ok"}
|
@api.post("/config/data/content_type/notion", status_code=200)
async def set_content_config_notion_data(
    request: Request,
    updated_config: Union[NotionContentConfig, None],
    client: Optional[str] = None,
):
    """Set the Notion content configuration and persist it to the config file."""
    _initialize_config()

    # Attach the notion section, creating the content-type container if missing
    if state.config.content_type:
        state.config.content_type.notion = updated_config
    else:
        state.config.content_type = ContentConfig(**{"notion": updated_config})

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": "notion"},
    )

    try:
        save_config_to_file_updated_state()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "ok"}
|
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
async def remove_content_config_data(
    request: Request,
    content_type: str,
    client: Optional[str] = None,
):
    """Remove the configuration and in-memory index for the given content type.

    No-op (status ok) when there is no configuration to remove. Persists the
    updated configuration to the config file.
    """
    if not state.config or not state.config.content_type:
        return {"status": "ok"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_content_config",
        client=client,
        metadata={"content_type": content_type},
    )

    if state.config.content_type:
        state.config.content_type[content_type] = None

    # Drop the in-memory index for this content type, if one exists.
    # Guard against state.content_index being None (nothing indexed yet),
    # which previously raised an AttributeError and returned a 500.
    if state.content_index is not None and content_type in (
        "github",
        "notion",
        "plugins",
        "pdf",
        "markdown",
        "org",
    ):
        setattr(state.content_index, content_type, None)

    try:
        save_config_to_file_updated_state()
        return {"status": "ok"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
|
@api.post("/delete/config/data/processor/conversation", status_code=200)
async def remove_processor_conversation_config_data(
    request: Request,
    client: Optional[str] = None,
):
    """Clear the conversation processor configuration and persist the change."""
    config = state.config
    if not config or not config.processor or not config.processor.conversation:
        # Nothing configured — nothing to remove
        return {"status": "ok"}

    config.processor.conversation = None

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="delete_processor_config",
        client=client,
        metadata={"processor_type": "conversation"},
    )

    try:
        save_config_to_file_updated_state()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "ok"}
|
@api.post("/config/data/content_type/{content_type}", status_code=200)
async def set_content_config_data(
    request: Request,
    content_type: str,
    updated_config: Union[TextContentConfig, None],
    client: Optional[str] = None,
):
    """Set the configuration for an arbitrary text content type and persist it."""
    _initialize_config()

    # Attach the section for this content type, creating the container if missing
    if state.config.content_type:
        state.config.content_type[content_type] = updated_config
    else:
        state.config.content_type = ContentConfig(**{content_type: updated_config})

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_content_config",
        client=client,
        metadata={"content_type": content_type},
    )

    try:
        save_config_to_file_updated_state()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "ok"}
|
@api.post("/config/data/processor/conversation", status_code=200)
async def set_processor_conversation_config_data(
    request: Request,
    updated_config: Union[ConversationProcessorConfig, None],
    client: Optional[str] = None,
):
    """Set the conversation processor config, reload the processor and persist."""
    _initialize_config()

    state.config.processor = ProcessorConfig(conversation=updated_config)
    state.processor_config = configure_processor(state.config.processor)

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        # NOTE(review): telemetry label says "set_content_config" although this is
        # a processor endpoint — looks like a copy-paste; confirm before renaming.
        api="set_content_config",
        client=client,
        metadata={"processor_type": "conversation"},
    )

    try:
        save_config_to_file_updated_state()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "ok"}
|
|
|
# Create Routes
@api.get("/config/data/default")
def get_default_config_data():
    # Expose the packaged default configuration (clients use it to seed new configs)
    return constants.default_config
|
|
|
@api.get("/config/types", response_model=List[str])
def get_config_types():
    """Get configured content types.

    A content type is reported when it is both configured and its index is
    loaded, when it is a configured plugin, or when it is the catch-all
    `SearchType.All`. Raises HTTP 500 when no content types are configured.
    """
    if state.config is None or state.config.content_type is None:
        raise HTTPException(
            status_code=500,
            detail="Content types not configured. Configure at least one content type on server and restart it.",
        )

    configured_content_types = state.config.content_type.dict(exclude_none=True)
    return [
        search_type.value
        for search_type in SearchType
        if (
            search_type.value in configured_content_types
            # Guard: previously getattr on a None content_index (configured but
            # not yet indexed) raised an AttributeError here.
            and state.content_index is not None
            and getattr(state.content_index, search_type.value, None) is not None
        )
        or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
        or search_type == SearchType.All
    ]
|
|
|
@api.get("/search", response_model=List[SearchResponse])
async def search(
    q: str,
    request: Request,
    n: Optional[int] = 5,
    t: Optional[SearchType] = SearchType.All,
    r: Optional[bool] = False,
    score_threshold: Optional[Union[float, None]] = None,
    dedupe: Optional[bool] = True,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Search the loaded content indices for query `q`.

    Parameters:
        q: search query (required; empty queries return no results)
        n: maximum number of results to return
        t: content type to search, or SearchType.All for every text type
        r: whether to rank results with the cross-encoder
        score_threshold: drop hits scoring below this value
        dedupe: whether to de-duplicate results
    Returns a score-sorted list of SearchResponse, served from an in-memory
    cache when the same query/parameters were seen before.
    """
    start_time = time.time()

    # Run validation checks
    results: List[SearchResponse] = []
    if q is None or q == "":
        logger.warning(f"No query param (q) passed in API call to initiate search")
        return results
    if not state.search_models or not any(state.search_models.__dict__.values()):
        logger.warning(f"No search models loaded. Configure a search model before initiating search")
        return results

    # initialize variables
    user_query = q.strip()
    results_count = n or 5
    score_threshold = score_threshold if score_threshold is not None else -math.inf
    # Normalize optional booleans once. Bug fix: the previous `dedupe or True`
    # always evaluated to True, making it impossible to disable deduplication.
    rank_results = bool(r)
    dedupe = True if dedupe is None else dedupe
    search_futures: List[concurrent.futures.Future] = []

    # return cached results, if available
    query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
    if query_cache_key in state.query_cache:
        logger.debug(f"Return response from query cache")
        return state.query_cache[query_cache_key]

    # Encode query with filter terms removed.
    # Bug fix: chain the filters. Previously each filter was applied to the
    # original `user_query`, so only the last filter's terms were removed.
    defiltered_query = user_query
    for filter in [DateFilter(), WordFilter(), FileFilter()]:
        defiltered_query = filter.defilter(defiltered_query)

    encoded_asymmetric_query = None
    if t != SearchType.Image:
        # All text content types share one text search model; use the first one loaded
        text_search_models: List[TextSearchModel] = [
            model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
        ]
        if text_search_models:
            with timer("Encoding query took", logger=logger):
                encoded_asymmetric_query = util.normalize_embeddings(
                    text_search_models[0].bi_encoder.encode(
                        [defiltered_query],
                        convert_to_tensor=True,
                        device=state.device,
                    )
                )

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Text content types all follow the same submission pattern; map each
        # search type to its content index and submit matching ones in parallel.
        text_content_targets = {
            SearchType.Org: state.content_index.org,
            SearchType.Markdown: state.content_index.markdown,
            SearchType.Github: state.content_index.github,
            SearchType.Pdf: state.content_index.pdf,
            SearchType.Notion: state.content_index.notion,
        }
        for search_type, content in text_content_targets.items():
            if (t == search_type or t == SearchType.All) and content and state.search_models.text_search:
                search_futures += [
                    executor.submit(
                        text_search.query,
                        user_query,
                        state.search_models.text_search,
                        content,
                        question_embedding=encoded_asymmetric_query,
                        rank_results=rank_results,
                        score_threshold=score_threshold,
                        dedupe=dedupe,
                    )
                ]

        if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
            # query images
            search_futures += [
                executor.submit(
                    image_search.query,
                    user_query,
                    results_count,
                    state.search_models.image_search,
                    state.content_index.image,
                    score_threshold=score_threshold,
                )
            ]

        # Bug fix: the previous condition `t == SearchType.All or t in SearchType`
        # was always true, so Image searches also submitted plugin text futures
        # whose (hits, entries) results then broke the image collation below.
        if t != SearchType.Image and state.content_index.plugins and state.search_models.plugin_search:
            # Get plugin content & search model for the specified search type,
            # or the first one if none specified
            plugin_search = state.search_models.plugin_search.get(t.value) or next(
                iter(state.search_models.plugin_search.values())
            )
            plugin_content = state.content_index.plugins.get(t.value) or next(
                iter(state.content_index.plugins.values())
            )
            search_futures += [
                executor.submit(
                    text_search.query,
                    user_query,
                    plugin_search,
                    plugin_content,
                    question_embedding=encoded_asymmetric_query,
                    rank_results=rank_results,
                    score_threshold=score_threshold,
                    dedupe=dedupe,
                )
            ]

        # Query across each requested content type in parallel
        with timer("Query took", logger):
            for search_future in concurrent.futures.as_completed(search_futures):
                if t == SearchType.Image and state.content_index.image:
                    hits = await search_future.result()
                    output_directory = constants.web_directory / "images"
                    # Collate image results
                    results += image_search.collate_results(
                        hits,
                        image_names=state.content_index.image.image_names,
                        output_directory=output_directory,
                        image_files_url="/static/images",
                        count=results_count,
                    )
                else:
                    hits, entries = await search_future.result()
                    # Collate text results
                    results += text_search.collate_results(hits, entries, results_count)

            # Sort results across all content types and take top results
            results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]

    # Cache results
    state.query_cache[query_cache_key] = results

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="search",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )

    state.previous_query = user_query

    end_time = time.time()
    logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")

    return results
|
|
@api.get("/update")
def update(
    request: Request,
    t: Optional[SearchType] = None,
    force: Optional[bool] = False,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Reload search models, content index and processor from the current config.

    `t` restricts re-indexing to one content type; `force=True` regenerates the
    index from scratch rather than updating incrementally. Raises HTTP 500 when
    reconfiguration fails.
    """
    try:
        # Serialize index updates so concurrent /update calls don't observe a
        # half-rebuilt index.
        state.search_index_lock.acquire()
        try:
            if state.config and state.config.search_type:
                # Reuses already-loaded search models where possible
                state.search_models = configure_search(state.search_models, state.config.search_type)
            if state.search_models:
                state.content_index = configure_content(
                    state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
                )
        except Exception as e:
            logger.error(e)
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            # Always release, even when reconfiguration raised
            state.search_index_lock.release()
    except ValueError as e:
        logger.error(e)
        raise HTTPException(status_code=500, detail=str(e))
    else:
        logger.info("📬 Search index updated via API")

    try:
        if state.config and state.config.processor:
            state.processor_config = configure_processor(state.config.processor)
    except ValueError as e:
        logger.error(e)
        raise HTTPException(status_code=500, detail=str(e))
    else:
        logger.info("📬 Processor reconfigured via API")

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="update",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )

    return {"status": "ok", "message": "khoj reloaded"}
|
|
|
@api.get("/chat/history")
def chat_history(
    request: Request,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Return the stored conversation history for this session."""
    perform_chat_checks()

    # Pull the conversation log kept on the conversation processor
    conversation_log = state.processor_config.conversation.meta_log

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="chat",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )

    return {"status": "ok", "response": conversation_log.get("chat", [])}
|
|
|
@api.get("/chat", response_class=Response)
async def chat(
    request: Request,
    q: str,
    n: Optional[int] = 5,
    client: Optional[str] = None,
    stream: Optional[bool] = False,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
) -> Response:
    """Answer the user's chat message `q`, optionally as a streamed response.

    Extracts up to `n` references from the knowledge base as context before
    generating the reply.
    """
    perform_chat_checks()
    compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))

    # Get the (streamed) chat response from GPT.
    gpt_response = generate_chat_response(
        q,
        meta_log=state.processor_config.conversation.meta_log,
        compiled_references=compiled_references,
        inferred_queries=inferred_queries,
    )
    if gpt_response is None:
        return Response(content=gpt_response, media_type="text/plain", status_code=500)

    if stream:
        return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)

    # Drain the response generator when the client did not request a stream
    aggregated_gpt_response = "".join(gpt_response)

    # Strip the compiled-references trailer appended to the raw response
    actual_response = aggregated_gpt_response.split("### compiled references:")[0]

    response_obj = {"response": actual_response, "context": compiled_references}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="chat",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )

    return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
|
|
async def extract_references_and_questions(
    request: Request,
    q: str,
    n: int,
):
    """Infer search queries from the chat message `q` and collect references.

    Returns (compiled_references, inferred_queries); both are empty when the
    message starts with "@general" (no notes lookup requested).
    """
    # Load Conversation History
    meta_log = state.processor_config.conversation.meta_log

    # Initialize Variables
    api_key = state.processor_config.conversation.openai_api_key
    chat_model = state.processor_config.conversation.chat_model
    compiled_references = []
    inferred_queries = []

    # "@general" prefix opts out of searching the user's notes
    if not q.startswith("@general"):
        # Infer search queries from user message
        with timer("Extracting search queries took", logger):
            inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)

        # Collate search results as context for GPT
        with timer("Searching knowledge base took", logger):
            result_list = []
            for query in inferred_queries:
                result_list.extend(
                    await search(query, request=request, n=n, r=True, score_threshold=-5.0, dedupe=False)
                )
            compiled_references = [item.additional["compiled"] for item in result_list]

    return compiled_references, inferred_queries