Merge pull request #317 from khoj-ai/reduce-memory-consumption-by-search-model-duplication

Reuse Search Models across Content Types to reduce Memory Consumption

- Memory consumption now only scales with search models used, not with content types. 
  Previously each content type had it's own copy of the search ML models. 
  That'd result in 300+ Mb per enabled text content type

- Split model state into 2 separate state objects, `search_models` and `content_index`. 
  This allows loading text_search and image_search models first
  and then reusing them across all content_types in content_index

- The change should cut down memory utilization quite a bit for most users.
  I see a >50% drop in memory utilization on my Khoj instance. 
  But this will vary for each user based on the amount of content indexed vs number of plugins enabled.

- This change does not solve the RAM utilization scaling with size of the index,
  as the whole content index is still kept in RAM while Khoj is running

Should help with #195, #301 and #303
This commit is contained in:
Debanjum
2023-07-14 19:54:12 -07:00
committed by GitHub
12 changed files with 361 additions and 196 deletions

View File

@@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@@ -49,11 +55,27 @@ def configure_server(args, required=False):
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
# Initialize the search type and model from Config
state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.model = configure_search(state.model, state.config, args.regenerate)
state.search_index_lock.release()
# Initialize Search Models from Config
try:
state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
except Exception as e:
logger.error(f"🚨 Error configuring search models on app load: {e}")
finally:
state.search_index_lock.release()
# Initialize Content from Config
if state.search_models:
try:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
)
except Exception as e:
logger.error(f"🚨 Error configuring content index on app load: {e}")
finally:
state.search_index_lock.release()
def configure_routes(app):
@@ -72,10 +94,16 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes)
def update_search_index():
state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, regenerate=False)
state.search_index_lock.release()
logger.info("📬 Search index updated via Scheduler")
try:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
)
logger.info("📬 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
finally:
state.search_index_lock.release()
def configure_search_types(config: FullConfig):
@@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
if config is None or config.content_type is None or config.search_type is None:
logger.warning("🚨 No Content or Search type is configured.")
return
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
# Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search type is configured.")
return None
if search_models is None:
search_models = SearchModels()
if model is None:
model = SearchModels()
# Initialize Search Models
if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
return None
if content_index is None:
content_index = ContentIndex()
try:
# Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup(
content_index.org = text_search.setup(
OrgToJsonl,
config.content_type.org,
search_config=config.search_type.asymmetric,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Markdown Search
if (
(t == state.SearchType.Markdown or t == None)
and config.content_type.markdown
and config.search_type.asymmetric
):
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(
content_index.markdown = text_search.setup(
MarkdownToJsonl,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup(
content_index.pdf = text_search.setup(
PdfToJsonl,
config.content_type.pdf,
search_config=config.search_type.asymmetric,
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup(
content_index.github = text_search.setup(
GithubToJsonl,
config.content_type.github,
search_config=config.search_type.asymmetric,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize External Plugin Search
if (t == None or t in state.SearchType) and config.content_type.plugins:
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
model.plugin_search = {}
for plugin_type, plugin_config in config.content_type.plugins.items():
model.plugin_search[plugin_type] = text_search.setup(
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
plugin_config,
search_config=config.search_type.asymmetric,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Notion Search
if (t == None or t in state.SearchType) and config.content_type.notion:
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion")
model.notion_search = text_search.setup(
content_index.notion = text_search.setup(
NotionToJsonl,
config.content_type.notion,
search_config=config.search_type.asymmetric,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
@@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Invalidate Query Cache
state.query_cache = LRU()
return model
return content_index
def configure_processor(processor_config: ProcessorConfig):

View File

@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.configure import configure_content, configure_processor, configure_search
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -163,17 +163,17 @@ if not state.demo:
state.config.content_type[content_type] = None
if content_type == "github":
state.model.github_search = None
state.content_index.github = None
elif content_type == "notion":
state.model.notion_search = None
state.content_index.notion = None
elif content_type == "plugins":
state.model.plugin_search = None
state.content_index.plugins = None
elif content_type == "pdf":
state.model.pdf_search = None
state.content_index.pdf = None
elif content_type == "markdown":
state.model.markdown_search = None
state.content_index.markdown = None
elif content_type == "org":
state.model.org_search = None
state.content_index.org = None
try:
save_config_to_file_updated_state()
@@ -280,7 +280,7 @@ def get_config_types():
for search_type in SearchType
if (
search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None
and getattr(state.content_index, search_type.value) is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
@@ -308,7 +308,7 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.model or not any(state.model.__dict__.values()):
if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results
@@ -332,7 +332,7 @@ async def search(
encoded_asymmetric_query = None
if t == SearchType.All or t != SearchType.Image:
text_search_models: List[TextSearchModel] = [
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
]
if text_search_models:
with timer("Encoding query took", logger=logger):
@@ -345,13 +345,14 @@ async def search(
)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.org_search,
state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -359,13 +360,18 @@ async def search(
)
]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.markdown_search,
state.search_models.text_search,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -373,13 +379,18 @@ async def search(
)
]
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.github_search,
state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -387,13 +398,14 @@ async def search(
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.pdf_search,
state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -401,26 +413,38 @@ async def search(
)
]
if (t == SearchType.Image) and state.model.image_search:
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
image_search.query,
user_query,
results_count,
state.model.image_search,
state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold,
)
]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [
executor.submit(
text_search.query,
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
plugin_search,
plugin_content,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -428,13 +452,18 @@ async def search(
)
]
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.notion_search,
state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@@ -445,13 +474,13 @@ async def search(
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image:
if t == SearchType.Image and state.content_index.image:
hits = await search_future.result()
output_directory = constants.web_directory / "images"
# Collate results
results += image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
image_names=state.content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
@@ -498,7 +527,12 @@ def update(
try:
state.search_index_lock.acquire()
try:
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
from PIL import Image
from tqdm import trange
import torch
from khoj.utils import state
# Internal Packages
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from khoj.utils.config import ImageSearchModel
from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder
return ImageSearchModel(encoder)
def extract_entries(image_directories):
@@ -143,7 +145,9 @@ def extract_metadata(image_name):
return image_processed_metadata
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
# Now we encode the query (which can either be an image or a text string)
with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
if content.image_metadata_embeddings:
with timer("Metadata Search Time", logger):
metadata_hits = {
result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images
@@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model(search_config)
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
# Extract Entries
absolute_image_files, filtered_image_files = set(), set()
if config.input_directories:
@@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
use_xmp_metadata=config.use_xmp_metadata,
)
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View File

@@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.jsonl import load_jsonl
@@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
@@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
device=f"{state.device}",
)
return bi_encoder, cross_encoder, top_k
return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]:
@@ -67,7 +64,7 @@ def compute_embeddings(
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings
@@ -104,17 +101,18 @@ def compute_embeddings(
async def query(
raw_query: str,
model: TextSearchModel,
search_model: TextSearchModel,
content: TextContent,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
@@ -127,18 +125,17 @@ async def query(
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
with timer("Search Time", logger, state.device):
hits = util.semantic_search(
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup(
text_to_jsonl: Type[TextToJsonl],
config: TextConfigBase,
search_config: TextSearchConfig,
bi_encoder: BaseEncoder,
regenerate: bool,
filters: List[BaseFilter] = [],
) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = (
@@ -192,7 +186,6 @@ def setup(
if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(f"No valid entries found in specified files: {config_params}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@@ -203,7 +196,7 @@ def setup(
for filter in filters:
filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
# External Packages
import torch
@@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
Conversation = "conversation"
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
@dataclass
class ImageContent:
image_names: List[str]
image_embeddings: torch.Tensor
image_metadata_embeddings: torch.Tensor
@dataclass
class TextSearchModel:
def __init__(
self,
entries: List[Entry],
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
bi_encoder: BaseEncoder
cross_encoder: Optional[CrossEncoder] = None
top_k: Optional[int] = 15
@dataclass
class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
image_encoder: BaseEncoder
@dataclass
class ContentIndex:
org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass
class SearchModels:
org_search: Union[TextSearchModel, None] = None
markdown_search: Union[TextSearchModel, None] = None
pdf_search: Union[TextSearchModel, None] = None
image_search: Union[ImageSearchModel, None] = None
github_search: Union[TextSearchModel, None] = None
notion_search: Union[TextSearchModel, None] = None
plugin_search: Union[Dict[str, TextSearchModel], None] = None
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
plugin_search: Optional[Dict[str, TextSearchModel]] = None
class ConversationProcessorConfigModel:

View File

@@ -20,7 +20,7 @@ from khoj.utils import constants
if TYPE_CHECKING:
# External Packages
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
# Internal Packages
from khoj.utils.models import BaseEncoder
@@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
def load_model(
model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
logger = logging.getLogger(__name__)

View File

@@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] = None
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)

View File

@@ -9,13 +9,14 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
model = SearchModels()
search_models = SearchModels()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
config_file: Path = None
verbose: int = 0

View File

@@ -10,6 +10,7 @@ from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import ImageContent, SearchModels, TextContent
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
@@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
encoder_type=None,
)
search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
encoder_type=None,
)
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
encoder_type=None,
)
return search_config
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig):
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1,
use_xmp_metadata=False,
)
image_search.setup(content_config.image, search_config.image, regenerate=False)
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
@@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
content_config.plugins = {
"plugin1": TextContentConfig(
@@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters
JsonlToJsonl,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
return content_config
@@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
# Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()]
state.model.markdown_search = text_search.setup(
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
# Initialize Processor from Config
@@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.model.org_search = {}
state.model.image_search = {}
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
configure_routes(app)
return TestClient(app)

View File

@@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
from khoj.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import model, config
from khoj.utils.config import SearchModels
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
# ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"),
@@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?")
# Act
@@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter(), FileFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
user_query = quote('+"Emacs" file:"*.org"')
@@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"')
@@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
filters = [WordFilter()]
model.org_search = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
user_query = quote('How to git install application? -"clone"')

View File

@@ -5,9 +5,10 @@ from PIL import Image
# External Packages
import pytest
from khoj.utils.config import SearchModels
# Internal Packages
from khoj.utils.state import model
from khoj.utils.state import content_index, search_models
from khoj.utils.constants import web_directory
from khoj.search_type import image_search
from khoj.utils.helpers import resolve_absolute_path
@@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test
# ----------------------------------------------------------------------------------------------------
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig):
def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate image search embeddings during image setup
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True)
image_search_model = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=True
)
# Assert
assert len(image_search_model.image_names) == 3
@@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
@pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [
("kitten", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"),
@@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
# Act
for query, expected_image_name in query_expected_image_pairs:
hits = await image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results(
hits,
model.image_search.image_names,
content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=1,
@@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
@pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
max_words_supported = 10
query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"] * max_words_supported)
@@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
# Act
try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
await image_search.query(query, count=1, model=model.image_search)
await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
# Assert
except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
@pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
image_directory = content_config.image.input_directories[0]
query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
@@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
# Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = await image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results(
hits,
model.image_search.image_names,
content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=1,

View File

@@ -5,9 +5,10 @@ import os
# External Packages
import pytest
from khoj.utils.config import SearchModels
# Internal Packages
from khoj.utils.state import model
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(notes_model.entries) == 10
@@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog):
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text
# Assert
@@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
@pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
@@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
assert len(regenerated_notes_model.entries) == 11
@@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify new entry added in updated embeddings, entries
@@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True)
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(github_model.entries) > 1