mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Pass default value matching argument types expected by text_search methods
This commit is contained in:
@@ -24,6 +24,7 @@ from khoj.search_filter.word_filter import WordFilter
|
|||||||
from khoj.utils.helpers import log_telemetry, timer
|
from khoj.utils.helpers import log_telemetry, timer
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
FullConfig,
|
FullConfig,
|
||||||
|
ProcessorConfig,
|
||||||
SearchResponse,
|
SearchResponse,
|
||||||
TextContentConfig,
|
TextContentConfig,
|
||||||
ConversationProcessorConfig,
|
ConversationProcessorConfig,
|
||||||
@@ -101,7 +102,10 @@ async def set_content_config_data(content_type: str, updated_config: TextContent
|
|||||||
|
|
||||||
@api.post("/config/data/processor/conversation", status_code=200)
|
@api.post("/config/data/processor/conversation", status_code=200)
|
||||||
async def set_processor_conversation_config_data(updated_config: ConversationProcessorConfig):
|
async def set_processor_conversation_config_data(updated_config: ConversationProcessorConfig):
|
||||||
state.config.processor.conversation = updated_config
|
if state.config.processor is None:
|
||||||
|
state.config.processor = ProcessorConfig(conversation=updated_config)
|
||||||
|
else:
|
||||||
|
state.config.processor.conversation = updated_config
|
||||||
try:
|
try:
|
||||||
save_config_to_file_updated_state()
|
save_config_to_file_updated_state()
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
@@ -139,6 +143,7 @@ def search(
|
|||||||
return state.query_cache[query_cache_key]
|
return state.query_cache[query_cache_key]
|
||||||
|
|
||||||
# Encode query with filter terms removed
|
# Encode query with filter terms removed
|
||||||
|
defiltered_query = user_query
|
||||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||||
defiltered_query = filter.defilter(user_query)
|
defiltered_query = filter.defilter(user_query)
|
||||||
|
|
||||||
@@ -162,9 +167,9 @@ def search(
|
|||||||
user_query,
|
user_query,
|
||||||
state.model.org_search,
|
state.model.org_search,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -176,9 +181,9 @@ def search(
|
|||||||
user_query,
|
user_query,
|
||||||
state.model.markdown_search,
|
state.model.markdown_search,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -190,9 +195,9 @@ def search(
|
|||||||
user_query,
|
user_query,
|
||||||
state.model.pdf_search,
|
state.model.pdf_search,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -203,9 +208,9 @@ def search(
|
|||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.ledger_search,
|
state.model.ledger_search,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -217,9 +222,9 @@ def search(
|
|||||||
user_query,
|
user_query,
|
||||||
state.model.music_search,
|
state.model.music_search,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -237,16 +242,16 @@ def search(
|
|||||||
|
|
||||||
if (t is None or t in SearchType) and state.model.plugin_search:
|
if (t is None or t in SearchType) and state.model.plugin_search:
|
||||||
# query specified plugin type
|
# query specified plugin type
|
||||||
search_future[t] += [
|
search_futures[t] += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
# Get plugin search model for specified search type, or the first one if none specified
|
# Get plugin search model for specified search type, or the first one if none specified
|
||||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe or True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -262,12 +267,12 @@ def search(
|
|||||||
image_names=state.model.image_search.image_names,
|
image_names=state.model.image_search.image_names,
|
||||||
output_directory=output_directory,
|
output_directory=output_directory,
|
||||||
image_files_url="/static/images",
|
image_files_url="/static/images",
|
||||||
count=results_count,
|
count=results_count or 5,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hits, entries = search_future.result()
|
hits, entries = search_future.result()
|
||||||
# Collate results
|
# Collate results
|
||||||
results += text_search.collate_results(hits, entries, results_count)
|
results += text_search.collate_results(hits, entries, results_count or 5)
|
||||||
|
|
||||||
# Sort results across all content types
|
# Sort results across all content types
|
||||||
results.sort(key=lambda x: float(x.score), reverse=True)
|
results.sort(key=lambda x: float(x.score), reverse=True)
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ def compute_embeddings(
|
|||||||
def query(
|
def query(
|
||||||
raw_query: str,
|
raw_query: str,
|
||||||
model: TextSearchModel,
|
model: TextSearchModel,
|
||||||
question_embedding: torch.Tensor = None,
|
question_embedding: torch.Tensor | None = None,
|
||||||
rank_results: bool = False,
|
rank_results: bool = False,
|
||||||
score_threshold: float = -math.inf,
|
score_threshold: float = -math.inf,
|
||||||
dedupe: bool = True,
|
dedupe: bool = True,
|
||||||
|
|||||||
Reference in New Issue
Block a user