mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Merge branch 'master' of github.com:debanjum/khoj
This commit is contained in:
@@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||
from khoj.utils.config import (
|
||||
ContentIndex,
|
||||
SearchType,
|
||||
SearchModels,
|
||||
ProcessorConfigModel,
|
||||
ConversationProcessorConfigModel,
|
||||
)
|
||||
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -49,11 +55,27 @@ def configure_server(args, required=False):
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor)
|
||||
|
||||
# Initialize the search type and model from Config
|
||||
state.search_index_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.model = configure_search(state.model, state.config, args.regenerate)
|
||||
state.search_index_lock.release()
|
||||
# Initialize Search Models from Config
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring search models on app load: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, args.regenerate
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring content index on app load: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
|
||||
def configure_routes(app):
|
||||
@@ -72,10 +94,16 @@ if not state.demo:
|
||||
|
||||
@schedule.repeat(schedule.every(61).minutes)
|
||||
def update_search_index():
|
||||
state.search_index_lock.acquire()
|
||||
state.model = configure_search(state.model, state.config, regenerate=False)
|
||||
state.search_index_lock.release()
|
||||
logger.info("📬 Search index updated via Scheduler")
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=False
|
||||
)
|
||||
logger.info("📬 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
|
||||
|
||||
def configure_search_types(config: FullConfig):
|
||||
@@ -90,94 +118,116 @@ def configure_search_types(config: FullConfig):
|
||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
|
||||
if config is None or config.content_type is None or config.search_type is None:
|
||||
logger.warning("🚨 No Content or Search type is configured.")
|
||||
return
|
||||
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
|
||||
# Run Validation Checks
|
||||
if search_config is None:
|
||||
logger.warning("🚨 No Search type is configured.")
|
||||
return None
|
||||
if search_models is None:
|
||||
search_models = SearchModels()
|
||||
|
||||
if model is None:
|
||||
model = SearchModels()
|
||||
# Initialize Search Models
|
||||
if search_config.asymmetric:
|
||||
logger.info("🔍 📜 Setting up text search model")
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
|
||||
if search_config.image:
|
||||
logger.info("🔍 🌄 Setting up image search model")
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
def configure_content(
|
||||
content_index: Optional[ContentIndex],
|
||||
content_config: Optional[ContentConfig],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool,
|
||||
t: Optional[state.SearchType] = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content type is configured.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.org_search = text_search.setup(
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
config.content_type.org,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (
|
||||
(t == state.SearchType.Markdown or t == None)
|
||||
and config.content_type.markdown
|
||||
and config.search_type.asymmetric
|
||||
):
|
||||
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(
|
||||
content_index.markdown = text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
config.content_type.markdown,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.markdown,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize PDF Search
|
||||
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
model.pdf_search = text_search.setup(
|
||||
content_index.pdf = text_search.setup(
|
||||
PdfToJsonl,
|
||||
config.content_type.pdf,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.pdf,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
|
||||
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
model.image_search = image_search.setup(
|
||||
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
|
||||
)
|
||||
|
||||
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
|
||||
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
model.github_search = text_search.setup(
|
||||
content_index.github = text_search.setup(
|
||||
GithubToJsonl,
|
||||
config.content_type.github,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.github,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize External Plugin Search
|
||||
if (t == None or t in state.SearchType) and config.content_type.plugins:
|
||||
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for plugins")
|
||||
model.plugin_search = {}
|
||||
for plugin_type, plugin_config in config.content_type.plugins.items():
|
||||
model.plugin_search[plugin_type] = text_search.setup(
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.setup(
|
||||
JsonlToJsonl,
|
||||
plugin_config,
|
||||
search_config=config.search_type.asymmetric,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
|
||||
# Initialize Notion Search
|
||||
if (t == None or t in state.SearchType) and config.content_type.notion:
|
||||
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
model.notion_search = text_search.setup(
|
||||
content_index.notion = text_search.setup(
|
||||
NotionToJsonl,
|
||||
config.content_type.notion,
|
||||
search_config=config.search_type.asymmetric,
|
||||
content_config.notion,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
)
|
||||
@@ -189,7 +239,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
||||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
|
||||
return model
|
||||
return content_index
|
||||
|
||||
|
||||
def configure_processor(processor_config: ProcessorConfig):
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.configure import configure_content, configure_processor, configure_search
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -163,17 +163,17 @@ if not state.demo:
|
||||
state.config.content_type[content_type] = None
|
||||
|
||||
if content_type == "github":
|
||||
state.model.github_search = None
|
||||
state.content_index.github = None
|
||||
elif content_type == "notion":
|
||||
state.model.notion_search = None
|
||||
state.content_index.notion = None
|
||||
elif content_type == "plugins":
|
||||
state.model.plugin_search = None
|
||||
state.content_index.plugins = None
|
||||
elif content_type == "pdf":
|
||||
state.model.pdf_search = None
|
||||
state.content_index.pdf = None
|
||||
elif content_type == "markdown":
|
||||
state.model.markdown_search = None
|
||||
state.content_index.markdown = None
|
||||
elif content_type == "org":
|
||||
state.model.org_search = None
|
||||
state.content_index.org = None
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
@@ -280,7 +280,7 @@ def get_config_types():
|
||||
for search_type in SearchType
|
||||
if (
|
||||
search_type.value in configured_content_types
|
||||
and getattr(state.model, f"{search_type.value}_search") is not None
|
||||
and getattr(state.content_index, search_type.value) is not None
|
||||
)
|
||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
||||
or search_type == SearchType.All
|
||||
@@ -308,7 +308,7 @@ async def search(
|
||||
if q is None or q == "":
|
||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
if not state.model or not any(state.model.__dict__.values()):
|
||||
if not state.search_models or not any(state.search_models.__dict__.values()):
|
||||
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
||||
return results
|
||||
|
||||
@@ -332,7 +332,7 @@ async def search(
|
||||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or t != SearchType.Image:
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
@@ -345,13 +345,14 @@ async def search(
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
|
||||
# query org-mode notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.org_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.org,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -359,13 +360,18 @@ async def search(
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
|
||||
if (
|
||||
(t == SearchType.Markdown or t == SearchType.All)
|
||||
and state.content_index.markdown
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query markdown notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.markdown_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.markdown,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -373,13 +379,18 @@ async def search(
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
|
||||
if (
|
||||
(t == SearchType.Github or t == SearchType.All)
|
||||
and state.content_index.github
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query github issues
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.github_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.github,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -387,13 +398,14 @@ async def search(
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
|
||||
# query pdf files
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.pdf_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.pdf,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -401,26 +413,38 @@ async def search(
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Image) and state.model.image_search:
|
||||
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||
# query images
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
image_search.query,
|
||||
user_query,
|
||||
results_count,
|
||||
state.model.image_search,
|
||||
state.search_models.image_search,
|
||||
state.content_index.image,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
|
||||
if (
|
||||
(t == SearchType.All or t in SearchType)
|
||||
and state.content_index.plugins
|
||||
and state.search_models.plugin_search
|
||||
):
|
||||
# query specified plugin type
|
||||
# Get plugin content, search model for specified search type, or the first one if none specified
|
||||
plugin_search = state.search_models.plugin_search.get(t.value) or next(
|
||||
iter(state.search_models.plugin_search.values())
|
||||
)
|
||||
plugin_content = state.content_index.plugins.get(t.value) or next(
|
||||
iter(state.content_index.plugins.values())
|
||||
)
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
# Get plugin search model for specified search type, or the first one if none specified
|
||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||
plugin_search,
|
||||
plugin_content,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -428,13 +452,18 @@ async def search(
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
|
||||
if (
|
||||
(t == SearchType.Notion or t == SearchType.All)
|
||||
and state.content_index.notion
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query notion pages
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.model.notion_search,
|
||||
state.search_models.text_search,
|
||||
state.content_index.notion,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
@@ -445,13 +474,13 @@ async def search(
|
||||
# Query across each requested content types in parallel
|
||||
with timer("Query took", logger):
|
||||
for search_future in concurrent.futures.as_completed(search_futures):
|
||||
if t == SearchType.Image:
|
||||
if t == SearchType.Image and state.content_index.image:
|
||||
hits = await search_future.result()
|
||||
output_directory = constants.web_directory / "images"
|
||||
# Collate results
|
||||
results += image_search.collate_results(
|
||||
hits,
|
||||
image_names=state.model.image_search.image_names,
|
||||
image_names=state.content_index.image.image_names,
|
||||
output_directory=output_directory,
|
||||
image_files_url="/static/images",
|
||||
count=results_count,
|
||||
@@ -498,7 +527,12 @@ def update(
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
try:
|
||||
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
|
||||
if state.config and state.config.search_type:
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
if state.search_models:
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
import torch
|
||||
from khoj.utils import state
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import ImageSearchModel
|
||||
from khoj.utils.config import ImageContent, ImageSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||
|
||||
|
||||
@@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
|
||||
model_type=search_config.encoder_type or SentenceTransformer,
|
||||
)
|
||||
|
||||
return encoder
|
||||
return ImageSearchModel(encoder)
|
||||
|
||||
|
||||
def extract_entries(image_directories):
|
||||
@@ -143,7 +145,9 @@ def extract_metadata(image_name):
|
||||
return image_processed_metadata
|
||||
|
||||
|
||||
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
|
||||
async def query(
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||
):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
|
||||
@@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
with timer("Query Encode Time", logger):
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {
|
||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
|
||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
if content.image_metadata_embeddings:
|
||||
with timer("Metadata Search Time", logger):
|
||||
metadata_hits = {
|
||||
result["corpus_id"]: result["score"]
|
||||
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
|
||||
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
@@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||
return results
|
||||
|
||||
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model(search_config)
|
||||
|
||||
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
|
||||
# Extract Entries
|
||||
absolute_image_files, filtered_image_files = set(), set()
|
||||
if config.input_directories:
|
||||
@@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
||||
use_xmp_metadata=config.use_xmp_metadata,
|
||||
)
|
||||
|
||||
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
|
||||
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
|
||||
|
||||
@@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.config import TextContent, TextSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
@@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
"Initialize model for semantic search on text"
|
||||
torch.set_num_threads(4)
|
||||
|
||||
# Number of entries we want to retrieve with the bi-encoder
|
||||
top_k = 15
|
||||
|
||||
# If model directory is configured
|
||||
if search_config.model_directory:
|
||||
# Convert model directory to absolute path
|
||||
@@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
device=f"{state.device}",
|
||||
)
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
return TextSearchModel(bi_encoder, cross_encoder)
|
||||
|
||||
|
||||
def extract_entries(jsonl_file) -> List[Entry]:
|
||||
@@ -67,7 +64,7 @@ def compute_embeddings(
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
@@ -104,17 +101,18 @@ def compute_embeddings(
|
||||
|
||||
async def query(
|
||||
raw_query: str,
|
||||
model: TextSearchModel,
|
||||
search_model: TextSearchModel,
|
||||
content: TextContent,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
|
||||
|
||||
# If no entries left after filtering, return empty results
|
||||
if entries is None or len(entries) == 0:
|
||||
@@ -127,18 +125,17 @@ async def query(
|
||||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = util.semantic_search(
|
||||
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
|
||||
)[0]
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
|
||||
if rank_results and search_model.cross_encoder:
|
||||
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
||||
@@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
|
||||
def setup(
|
||||
text_to_jsonl: Type[TextToJsonl],
|
||||
config: TextConfigBase,
|
||||
search_config: TextSearchConfig,
|
||||
bi_encoder: BaseEncoder,
|
||||
regenerate: bool,
|
||||
filters: List[BaseFilter] = [],
|
||||
) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
previous_entries = (
|
||||
@@ -192,7 +186,6 @@ def setup(
|
||||
if is_none_or_empty(entries):
|
||||
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
||||
raise ValueError(f"No valid entries found in specified files: {config_params}")
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
@@ -203,7 +196,7 @@ def setup(
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def apply_filters(
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
@@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
|
||||
Conversation = "conversation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
entries: List[Entry]
|
||||
corpus_embeddings: torch.Tensor
|
||||
filters: List[BaseFilter]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
image_names: List[str]
|
||||
image_embeddings: torch.Tensor
|
||||
image_metadata_embeddings: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextSearchModel:
|
||||
def __init__(
|
||||
self,
|
||||
entries: List[Entry],
|
||||
corpus_embeddings: torch.Tensor,
|
||||
bi_encoder: BaseEncoder,
|
||||
cross_encoder: CrossEncoder,
|
||||
filters: List[BaseFilter],
|
||||
top_k,
|
||||
):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.filters = filters
|
||||
self.top_k = top_k
|
||||
bi_encoder: BaseEncoder
|
||||
cross_encoder: Optional[CrossEncoder] = None
|
||||
top_k: Optional[int] = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageSearchModel:
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
self.image_metadata_embeddings = image_metadata_embeddings
|
||||
self.image_encoder = image_encoder
|
||||
image_encoder: BaseEncoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentIndex:
|
||||
org: Optional[TextContent] = None
|
||||
markdown: Optional[TextContent] = None
|
||||
pdf: Optional[TextContent] = None
|
||||
github: Optional[TextContent] = None
|
||||
notion: Optional[TextContent] = None
|
||||
image: Optional[ImageContent] = None
|
||||
plugins: Optional[Dict[str, TextContent]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchModels:
|
||||
org_search: Union[TextSearchModel, None] = None
|
||||
markdown_search: Union[TextSearchModel, None] = None
|
||||
pdf_search: Union[TextSearchModel, None] = None
|
||||
image_search: Union[ImageSearchModel, None] = None
|
||||
github_search: Union[TextSearchModel, None] = None
|
||||
notion_search: Union[TextSearchModel, None] = None
|
||||
plugin_search: Union[Dict[str, TextSearchModel], None] = None
|
||||
text_search: Optional[TextSearchModel] = None
|
||||
image_search: Optional[ImageSearchModel] = None
|
||||
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
||||
|
||||
|
||||
class ConversationProcessorConfigModel:
|
||||
|
||||
@@ -20,7 +20,7 @@ from khoj.utils import constants
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# External Packages
|
||||
from sentence_transformers import CrossEncoder
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.models import BaseEncoder
|
||||
@@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
||||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
|
||||
def load_model(
|
||||
model_name: str, model_type, model_dir=None, device: str = None
|
||||
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
|
||||
|
||||
|
||||
class FullConfig(ConfigBase):
|
||||
content_type: Optional[ContentConfig]
|
||||
search_type: Optional[SearchConfig]
|
||||
processor: Optional[ProcessorConfig]
|
||||
content_type: Optional[ContentConfig] = None
|
||||
search_type: Optional[SearchConfig] = None
|
||||
processor: Optional[ProcessorConfig] = None
|
||||
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
|
||||
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ from pathlib import Path
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.helpers import LRU
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
model = SearchModels()
|
||||
search_models = SearchModels()
|
||||
content_index = ContentIndex()
|
||||
processor_config = ProcessorConfigModel()
|
||||
config_file: Path = None
|
||||
verbose: int = 0
|
||||
|
||||
Reference in New Issue
Block a user