Reuse Search Models across Content Types to Reduce Memory Consumption

- Memory consumption now only scales with search models used, not with
  content types as well. Previously each content type had its own
  copy of the search ML models. That'd result in 300+ MB per enabled
  content type.

- Split model state into 2 separate state objects, `search_models` and
  `content_index`.
  This allows loading text_search and image_search models first and then
  reusing them across all content_types in content_index

- This should cut down memory utilization quite a bit for most users.
  I see a ~50% drop in memory utilization.

  This will, of course, vary for each user based on the amount of
  content indexed vs number of plugins enabled

- This does not solve the RAM utilization scaling with size of the index.
  As the whole content index is still kept in RAM while Khoj is running

Should help with #195, #301 and #303
This commit is contained in:
Debanjum Singh Solanky
2023-07-14 01:07:44 -07:00
parent c2249eadb2
commit 86e2bec9a0
8 changed files with 217 additions and 142 deletions

View File

@@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
@@ -49,10 +55,18 @@ def configure_server(args, required=False):
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor) state.processor_config = configure_processor(args.config.processor)
# Initialize the search type and model from Config # Initialize Search Models from Config
state.search_index_lock.acquire() state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
state.model = configure_search(state.model, state.config, args.regenerate) state.search_models = configure_search(state.search_models, state.config.search_type)
state.search_index_lock.release()
# Initialize Content from Config
if state.search_models:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
)
state.search_index_lock.release() state.search_index_lock.release()
@@ -73,7 +87,9 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes) @schedule.repeat(schedule.every(61).minutes)
def update_search_index(): def update_search_index():
state.search_index_lock.acquire() state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, regenerate=False) state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
)
state.search_index_lock.release() state.search_index_lock.release()
logger.info("📬 Search index updated via Scheduler") logger.info("📬 Search index updated via Scheduler")
@@ -90,94 +106,116 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None): def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
if config is None or config.content_type is None or config.search_type is None: # Run Validation Checks
logger.warning("🚨 No Content or Search type is configured.") if search_config is None:
return logger.warning("🚨 No Search type is configured.")
return None
if search_models is None:
search_models = SearchModels()
if model is None: # Initialize Search Models
model = SearchModels() if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
return None
if content_index is None:
content_index = ContentIndex()
try: try:
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric: if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes") logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup( content_index.org = text_search.setup(
OrgToJsonl, OrgToJsonl,
config.content_type.org, content_config.org,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Markdown Search # Initialize Markdown Search
if ( if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
(t == state.SearchType.Markdown or t == None)
and config.content_type.markdown
and config.search_type.asymmetric
):
logger.info("💎 Setting up search for markdown notes") logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings # Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup( content_index.markdown = text_search.setup(
MarkdownToJsonl, MarkdownToJsonl,
config.content_type.markdown, content_config.markdown,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize PDF Search # Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric: if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf") logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings # Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup( content_index.pdf = text_search.setup(
PdfToJsonl, PdfToJsonl,
config.content_type.pdf, content_config.pdf,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Image Search # Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image: if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images") logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup( content_index.image = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
) )
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric: if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github") logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings # Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup( content_index.github = text_search.setup(
GithubToJsonl, GithubToJsonl,
config.content_type.github, content_config.github,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize External Plugin Search # Initialize External Plugin Search
if (t == None or t in state.SearchType) and config.content_type.plugins: if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins") logger.info("🔌 Setting up search for plugins")
model.plugin_search = {} content_index.plugins = {}
for plugin_type, plugin_config in config.content_type.plugins.items(): for plugin_type, plugin_config in content_config.plugins.items():
model.plugin_search[plugin_type] = text_search.setup( content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl, JsonlToJsonl,
plugin_config, plugin_config,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
# Initialize Notion Search # Initialize Notion Search
if (t == None or t in state.SearchType) and config.content_type.notion: if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion") logger.info("🔌 Setting up search for notion")
model.notion_search = text_search.setup( content_index.notion = text_search.setup(
NotionToJsonl, NotionToJsonl,
config.content_type.notion, content_config.notion,
search_config=config.search_type.asymmetric, search_models.text_search.bi_encoder,
regenerate=regenerate, regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()], filters=[DateFilter(), WordFilter(), FileFilter()],
) )
@@ -189,7 +227,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Invalidate Query Cache # Invalidate Query Cache
state.query_cache = LRU() state.query_cache = LRU()
return model return content_index
def configure_processor(processor_config: ProcessorConfig): def configure_processor(processor_config: ProcessorConfig):

View File

@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util from sentence_transformers import util
# Internal Packages # Internal Packages
from khoj.configure import configure_processor, configure_search from khoj.configure import configure_content, configure_processor, configure_search
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
@@ -102,17 +102,17 @@ if not state.demo:
state.config.content_type[content_type] = None state.config.content_type[content_type] = None
if content_type == "github": if content_type == "github":
state.model.github_search = None state.content_index.github = None
elif content_type == "notion": elif content_type == "notion":
state.model.notion_search = None state.content_index.notion = None
elif content_type == "plugins": elif content_type == "plugins":
state.model.plugin_search = None state.content_index.plugins = None
elif content_type == "pdf": elif content_type == "pdf":
state.model.pdf_search = None state.content_index.pdf = None
elif content_type == "markdown": elif content_type == "markdown":
state.model.markdown_search = None state.content_index.markdown = None
elif content_type == "org": elif content_type == "org":
state.model.org_search = None state.content_index.org = None
try: try:
save_config_to_file_updated_state() save_config_to_file_updated_state()
@@ -182,7 +182,7 @@ def get_config_types():
for search_type in SearchType for search_type in SearchType
if ( if (
search_type.value in configured_content_types search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None and getattr(state.content_index, search_type.value) is not None
) )
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All or search_type == SearchType.All
@@ -210,7 +210,7 @@ async def search(
if q is None or q == "": if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search") logger.warning(f"No query param (q) passed in API call to initiate search")
return results return results
if not state.model or not any(state.model.__dict__.values()): if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search") logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results return results
@@ -234,7 +234,7 @@ async def search(
encoded_asymmetric_query = None encoded_asymmetric_query = None
if t == SearchType.All or t != SearchType.Image: if t == SearchType.All or t != SearchType.Image:
text_search_models: List[TextSearchModel] = [ text_search_models: List[TextSearchModel] = [
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel) model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
] ]
if text_search_models: if text_search_models:
with timer("Encoding query took", logger=logger): with timer("Encoding query took", logger=logger):
@@ -247,13 +247,14 @@ async def search(
) )
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search: if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes # query org-mode notes
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.org_search, state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -261,13 +262,18 @@ async def search(
) )
] ]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search: if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes # query markdown notes
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.markdown_search, state.search_models.text_search,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -275,13 +281,18 @@ async def search(
) )
] ]
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search: if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues # query github issues
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.github_search, state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -289,13 +300,14 @@ async def search(
) )
] ]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search: if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files # query pdf files
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.pdf_search, state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -303,26 +315,38 @@ async def search(
) )
] ]
if (t == SearchType.Image) and state.model.image_search: if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images # query images
search_futures += [ search_futures += [
executor.submit( executor.submit(
image_search.query, image_search.query,
user_query, user_query,
results_count, results_count,
state.model.image_search, state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold, score_threshold=score_threshold,
) )
] ]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search: if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type # query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
# Get plugin search model for specified search type, or the first one if none specified plugin_search,
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), plugin_content,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -330,13 +354,18 @@ async def search(
) )
] ]
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search: if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages # query notion pages
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.notion_search, state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
@@ -347,13 +376,13 @@ async def search(
# Query across each requested content types in parallel # Query across each requested content types in parallel
with timer("Query took", logger): with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures): for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image: if t == SearchType.Image and state.content_index.image:
hits = await search_future.result() hits = await search_future.result()
output_directory = constants.web_directory / "images" output_directory = constants.web_directory / "images"
# Collate results # Collate results
results += image_search.collate_results( results += image_search.collate_results(
hits, hits,
image_names=state.model.image_search.image_names, image_names=state.content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=results_count, count=results_count,
@@ -404,7 +433,12 @@ def update(
try: try:
state.search_index_lock.acquire() state.search_index_lock.acquire()
try: try:
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t) if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
from PIL import Image from PIL import Image
from tqdm import trange from tqdm import trange
import torch import torch
from khoj.utils import state
# Internal Packages # Internal Packages
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from khoj.utils.config import ImageSearchModel from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
model_type=search_config.encoder_type or SentenceTransformer, model_type=search_config.encoder_type or SentenceTransformer,
) )
return encoder return ImageSearchModel(encoder)
def extract_entries(image_directories): def extract_entries(image_directories):
@@ -143,7 +145,9 @@ def extract_metadata(image_name):
return image_processed_metadata return image_processed_metadata
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
# Now we encode the query (which can either be an image or a text string) # Now we encode the query (which can either be an image or a text string)
with timer("Query Encode Time", logger): with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger): with timer("Search Time", logger):
image_hits = { image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
} }
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings: if content.image_metadata_embeddings:
with timer("Metadata Search Time", logger): with timer("Metadata Search Time", logger):
metadata_hits = { metadata_hits = {
result["corpus_id"]: result["score"] result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
} }
# Sum metadata, image scores of the highest ranked images # Sum metadata, image scores of the highest ranked images
@@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
# Initialize Model
encoder = initialize_model(search_config)
# Extract Entries # Extract Entries
absolute_image_files, filtered_image_files = set(), set() absolute_image_files, filtered_image_files = set(), set()
if config.input_directories: if config.input_directories:
@@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
use_xmp_metadata=config.use_xmp_metadata, use_xmp_metadata=config.use_xmp_metadata,
) )
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder) return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View File

@@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
# Internal Packages # Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.models import BaseEncoder from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.jsonl import load_jsonl from khoj.utils.jsonl import load_jsonl
@@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text" "Initialize model for semantic search on text"
torch.set_num_threads(4) torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# If model directory is configured # If model directory is configured
if search_config.model_directory: if search_config.model_directory:
# Convert model directory to absolute path # Convert model directory to absolute path
@@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
device=f"{state.device}", device=f"{state.device}",
) )
return bi_encoder, cross_encoder, top_k return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]: def extract_entries(jsonl_file) -> List[Entry]:
@@ -67,7 +64,7 @@ def compute_embeddings(
new_entries = [] new_entries = []
# Load pre-computed embeddings from file if exists and update them if required # Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate: if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}") logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings # Encode any new entries in the corpus and update corpus embeddings
@@ -104,17 +101,18 @@ def compute_embeddings(
async def query( async def query(
raw_query: str, raw_query: str,
model: TextSearchModel, search_model: TextSearchModel,
content: TextContent,
question_embedding: Union[torch.Tensor, None] = None, question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False, rank_results: bool = False,
score_threshold: float = -math.inf, score_threshold: float = -math.inf,
dedupe: bool = True, dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]: ) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results # If no entries left after filtering, return empty results
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
@@ -127,18 +125,17 @@ async def query(
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
if question_embedding is None: if question_embedding is None:
with timer("Query Encode Time", logger, state.device): with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query # Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
with timer("Search Time", logger, state.device): with timer("Search Time", logger, state.device):
hits = util.semantic_search( hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits) hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold # Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold] hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup( def setup(
text_to_jsonl: Type[TextToJsonl], text_to_jsonl: Type[TextToJsonl],
config: TextConfigBase, config: TextConfigBase,
search_config: TextSearchConfig, bi_encoder: BaseEncoder,
regenerate: bool, regenerate: bool,
filters: List[BaseFilter] = [], filters: List[BaseFilter] = [],
) -> TextSearchModel: ) -> TextContent:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in text files to (compressed) JSONL formatted file # Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = ( previous_entries = (
@@ -192,7 +186,6 @@ def setup(
if is_none_or_empty(entries): if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()]) config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(f"No valid entries found in specified files: {config_params}") raise ValueError(f"No valid entries found in specified files: {config_params}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings # Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file) config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@@ -203,7 +196,7 @@ def setup(
for filter in filters: for filter in filters:
filter.load(entries, regenerate=regenerate) filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) return TextContent(entries, corpus_embeddings, filters)
def apply_filters( def apply_filters(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union
# External Packages # External Packages
import torch import torch
@@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
Conversation = "conversation" Conversation = "conversation"
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
@dataclass
class ImageContent:
image_names: List[str]
image_embeddings: torch.Tensor
image_metadata_embeddings: torch.Tensor
@dataclass
class TextSearchModel: class TextSearchModel:
def __init__( bi_encoder: BaseEncoder
self, cross_encoder: Optional[CrossEncoder] = None
entries: List[Entry], top_k: Optional[int] = 15
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
@dataclass
class ImageSearchModel: class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder): image_encoder: BaseEncoder
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings @dataclass
self.image_metadata_embeddings = image_metadata_embeddings class ContentIndex:
self.image_encoder = image_encoder org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass @dataclass
class SearchModels: class SearchModels:
org_search: Union[TextSearchModel, None] = None text_search: Optional[TextSearchModel] = None
markdown_search: Union[TextSearchModel, None] = None image_search: Optional[ImageSearchModel] = None
pdf_search: Union[TextSearchModel, None] = None plugin_search: Optional[Dict[str, TextSearchModel]] = None
image_search: Union[ImageSearchModel, None] = None
github_search: Union[TextSearchModel, None] = None
notion_search: Union[TextSearchModel, None] = None
plugin_search: Union[Dict[str, TextSearchModel], None] = None
class ConversationProcessorConfigModel: class ConversationProcessorConfigModel:

View File

@@ -20,7 +20,7 @@ from khoj.utils import constants
if TYPE_CHECKING: if TYPE_CHECKING:
# External Packages # External Packages
from sentence_transformers import CrossEncoder from sentence_transformers import SentenceTransformer, CrossEncoder
# Internal Packages # Internal Packages
from khoj.utils.models import BaseEncoder from khoj.utils.models import BaseEncoder
@@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]: def load_model(
model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
"Load model from disk or huggingface" "Load model from disk or huggingface"
# Construct model path # Construct model path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
class FullConfig(ConfigBase): class FullConfig(ConfigBase):
content_type: Optional[ContentConfig] content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] processor: Optional[ProcessorConfig] = None
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True) app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)

View File

@@ -9,13 +9,14 @@ from pathlib import Path
# Internal Packages # Internal Packages
from khoj.utils import config as utils_config from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
model = SearchModels() search_models = SearchModels()
content_index = ContentIndex()
processor_config = ProcessorConfigModel() processor_config = ProcessorConfigModel()
config_file: Path = None config_file: Path = None
verbose: int = 0 verbose: int = 0