mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Resolve merge conflicts in dependency imports
This commit is contained in:
@@ -240,10 +240,18 @@ class ConversationAdapters:
|
||||
def get_openai_conversation_config():
|
||||
return OpenAIProcessorConversationConfig.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_openai_conversation_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_offline_chat_conversation_config():
|
||||
return OfflineChatProcessorConversationConfig.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_offline_chat_conversation_config():
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_offline_conversation_config():
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
|
||||
@@ -267,10 +275,21 @@ class ConversationAdapters:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_config(user: KhojUser):
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
def get_default_conversation_config():
|
||||
return ChatModelOptions.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config():
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def save_conversation(user: KhojUser, conversation_log: dict):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
@@ -320,10 +339,6 @@ class ConversationAdapters:
|
||||
async def get_openai_chat_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config():
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "Khoj",
|
||||
"version": "0.14.0",
|
||||
"version": "1.0.0",
|
||||
"description": "An AI copilot for your Second Brain",
|
||||
"author": "Saba Imran, Debanjum Singh Solanky <team@khoj.dev>",
|
||||
"license": "GPL-3.0-or-later",
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
;; Saba Imran <saba@khoj.dev>
|
||||
;; Description: An AI copilot for your Second Brain
|
||||
;; Keywords: search, chat, org-mode, outlines, markdown, pdf, image
|
||||
;; Version: 0.14.0
|
||||
;; Version: 1.0.0
|
||||
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
|
||||
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
;; Khoj Static Configuration
|
||||
;; -------------------------
|
||||
|
||||
(defcustom khoj-server-url "http://localhost:42110"
|
||||
(defcustom khoj-server-url "https://app.khoj.dev"
|
||||
"Location of Khoj API server."
|
||||
:group 'khoj
|
||||
:type 'string)
|
||||
@@ -94,7 +94,7 @@
|
||||
:type 'number)
|
||||
|
||||
(defcustom khoj-api-key nil
|
||||
"API Key to Khoj server."
|
||||
"API Key to your Khoj. Default at https://app.khoj.dev/config#clients."
|
||||
:group 'khoj
|
||||
:type 'string)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"id": "khoj",
|
||||
"name": "Khoj",
|
||||
"version": "0.14.0",
|
||||
"version": "1.0.0",
|
||||
"minAppVersion": "0.15.0",
|
||||
"description": "An AI copilot for your Second Brain",
|
||||
"author": "Khoj Inc.",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "Khoj",
|
||||
"version": "0.14.0",
|
||||
"version": "1.0.0",
|
||||
"description": "An AI copilot for your Second Brain",
|
||||
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
|
||||
"license": "GPL-3.0-or-later",
|
||||
|
||||
@@ -75,7 +75,7 @@ export default class Khoj extends Plugin {
|
||||
|
||||
if (this.settings.khojUrl === "https://app.khoj.dev") {
|
||||
if (this.settings.khojApiKey === "") {
|
||||
new Notice(`❗️Khoj API key is not configured. Please visit https://app.khoj.dev to get an API key.`);
|
||||
new Notice(`❗️Khoj API key is not configured. Please visit https://app.khoj.dev/config#clients to get an API key.`);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ export interface KhojSetting {
|
||||
|
||||
export const DEFAULT_SETTINGS: KhojSetting = {
|
||||
resultsCount: 6,
|
||||
khojUrl: 'http://127.0.0.1:42110',
|
||||
khojUrl: 'https://app.khoj.dev',
|
||||
khojApiKey: '',
|
||||
connectedToBackend: false,
|
||||
autoConfigure: true,
|
||||
|
||||
@@ -26,5 +26,6 @@
|
||||
"0.12.2": "0.15.0",
|
||||
"0.12.3": "0.15.0",
|
||||
"0.13.0": "0.15.0",
|
||||
"0.14.0": "0.15.0"
|
||||
"0.14.0": "0.15.0",
|
||||
"1.0.0": "0.15.0"
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
{% block content %}
|
||||
|
||||
<div class="page">
|
||||
<div class="section">
|
||||
<div id="content" class="section">
|
||||
<h2 class="section-title">Content</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
@@ -118,7 +118,7 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section">
|
||||
<div id ="features" class="section">
|
||||
<h2 class="section-title">Features</h2>
|
||||
<div id="features-hint-text"></div>
|
||||
<div class="section-cards">
|
||||
@@ -144,9 +144,9 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section">
|
||||
<div id="clients" class="section">
|
||||
<h2 class="section-title">Clients</h2>
|
||||
<div class="api-settings">
|
||||
<div id="clients-api" class="api-settings">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/key.svg" alt="API Key">
|
||||
<h3 class="card-title">API Keys</h3>
|
||||
@@ -172,7 +172,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{% if billing_enabled %}
|
||||
<div class="section">
|
||||
<div id="billing" class="section">
|
||||
<h2 class="section-title">Billing</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Standard Packages
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
@@ -31,6 +32,10 @@ def extract_questions(
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
|
||||
def _valid_question(question: str):
|
||||
return not is_none_or_empty(question) and question != "[]"
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
@@ -70,7 +75,7 @@ def extract_questions(
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
questions = (
|
||||
split_questions = (
|
||||
response.content.strip(empty_escape_sequences)
|
||||
.replace("['", '["')
|
||||
.replace("']", '"]')
|
||||
@@ -79,9 +84,18 @@ def extract_questions(
|
||||
.replace('"]', "")
|
||||
.split('", "')
|
||||
)
|
||||
questions = []
|
||||
|
||||
for question in split_questions:
|
||||
if question not in questions and _valid_question(question):
|
||||
questions.append(question)
|
||||
|
||||
if is_none_or_empty(questions):
|
||||
raise ValueError("GPT returned empty JSON")
|
||||
except:
|
||||
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
|
||||
logger.debug(f"Extracted Questions by GPT: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
@@ -154,17 +154,20 @@ def truncate_messages(
|
||||
)
|
||||
|
||||
system_message = messages.pop()
|
||||
assert type(system_message.content) == str
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
assert type(system_message.content) == str
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||
|
||||
# Truncate current message if still over max supported prompt size by model
|
||||
if (tokens + system_message_tokens) > max_prompt_size:
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1])
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:])
|
||||
assert type(system_message.content) == str
|
||||
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||
original_question_tokens = len(encoder.encode(original_question))
|
||||
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
||||
|
||||
@@ -31,6 +31,7 @@ from khoj.utils import state, constants
|
||||
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from khoj.routers.helpers import (
|
||||
CommonQueryParams,
|
||||
get_conversation_command,
|
||||
validate_conversation_config,
|
||||
agenerate_chat_response,
|
||||
@@ -55,6 +56,7 @@ from database.models import (
|
||||
Entry as DbEntry,
|
||||
GithubConfig,
|
||||
NotionConfig,
|
||||
ChatModelOptions,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,7 +124,7 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
|
||||
def _initialize_config():
|
||||
if state.config is None:
|
||||
state.config = FullConfig()
|
||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
||||
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
@@ -355,15 +357,12 @@ def get_config_types(
|
||||
async def search(
|
||||
q: str,
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
user = request.user.object
|
||||
start_time = time.time()
|
||||
@@ -467,10 +466,7 @@ async def search(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="search",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
@@ -483,12 +479,9 @@ async def search(
|
||||
@requires(["authenticated"])
|
||||
def update(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
t: Optional[SearchType] = None,
|
||||
force: Optional[bool] = False,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
user = request.user.object
|
||||
if not state.config:
|
||||
@@ -514,10 +507,7 @@ def update(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="update",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
return {"status": "ok", "message": "khoj reloaded"}
|
||||
@@ -527,10 +517,7 @@ def update(
|
||||
@requires(["authenticated"])
|
||||
def chat_history(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
common: CommonQueryParams,
|
||||
):
|
||||
user = request.user.object
|
||||
validate_conversation_config()
|
||||
@@ -542,10 +529,7 @@ def chat_history(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||
@@ -555,10 +539,7 @@ def chat_history(
|
||||
@requires(["authenticated"])
|
||||
async def chat_options(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
common: CommonQueryParams,
|
||||
) -> Response:
|
||||
cmd_options = {}
|
||||
for cmd in ConversationCommand:
|
||||
@@ -568,10 +549,7 @@ async def chat_options(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat_options",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
**common.__dict__,
|
||||
)
|
||||
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
|
||||
|
||||
@@ -580,14 +558,11 @@ async def chat_options(
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
@@ -601,7 +576,7 @@ async def chat(
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
)
|
||||
online_results: Dict = dict()
|
||||
|
||||
@@ -647,11 +622,8 @@ async def chat(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
metadata=chat_metadata,
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
if llm_response is None:
|
||||
@@ -678,6 +650,7 @@ async def chat(
|
||||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
@@ -710,7 +683,16 @@ async def extract_references_and_questions(
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
if await ConversationAdapters.ahas_offline_chat():
|
||||
offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
|
||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
if (
|
||||
offline_chat_config
|
||||
and offline_chat_config.enabled
|
||||
and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
|
||||
):
|
||||
using_offline_chat = True
|
||||
offline_chat = await ConversationAdapters.get_offline_chat()
|
||||
chat_model = offline_chat.chat_model
|
||||
@@ -722,7 +704,7 @@ async def extract_references_and_questions(
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif await ConversationAdapters.has_openai_chat():
|
||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||
api_key = openai_chat_config.api_key
|
||||
@@ -744,9 +726,9 @@ async def extract_references_and_questions(
|
||||
r=True,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
common=common,
|
||||
)
|
||||
)
|
||||
# Dedupe the results again, as duplicates may be returned across queries.
|
||||
result_list = text_search.deduplicated_search_responses(result_list)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ from datetime import datetime
|
||||
from functools import partial
|
||||
import logging
|
||||
from time import time
|
||||
from typing import Iterator, List, Optional, Union, Tuple, Dict, Any
|
||||
from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any
|
||||
|
||||
# External Packages
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi import HTTPException, Header, Request, Depends
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
@@ -232,3 +232,20 @@ class ApiUserRateLimiter:
|
||||
|
||||
# Add the current request to the cache
|
||||
user_requests.append(time())
|
||||
|
||||
|
||||
class CommonQueryParamsClass:
|
||||
def __init__(
|
||||
self,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
self.client = client
|
||||
self.user_agent = user_agent
|
||||
self.referer = referer
|
||||
self.host = host
|
||||
|
||||
|
||||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||
|
||||
@@ -63,7 +63,7 @@ async def update(
|
||||
request: Request,
|
||||
files: list[UploadFile],
|
||||
force: bool = False,
|
||||
t: Optional[Union[state.SearchType, str]] = None,
|
||||
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
@@ -182,13 +182,16 @@ def configure_content(
|
||||
files: Optional[dict[str, dict[str, str]]],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool = False,
|
||||
t: Optional[state.SearchType] = None,
|
||||
t: Optional[state.SearchType] = state.SearchType.All,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
) -> tuple[Optional[ContentIndex], bool]:
|
||||
content_index = ContentIndex()
|
||||
|
||||
success = True
|
||||
if t is not None and t in [type.value for type in state.SearchType]:
|
||||
t = state.SearchType(t)
|
||||
|
||||
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
||||
logger.warning(f"🚨 Invalid search type: {t}")
|
||||
return None, False
|
||||
@@ -201,7 +204,7 @@ def configure_content(
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
|
||||
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
text_search.setup(
|
||||
@@ -217,7 +220,9 @@ def configure_content(
|
||||
|
||||
try:
|
||||
# Initialize Markdown Search
|
||||
if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
|
||||
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[
|
||||
"markdown"
|
||||
]:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
text_search.setup(
|
||||
@@ -234,7 +239,7 @@ def configure_content(
|
||||
|
||||
try:
|
||||
# Initialize PDF Search
|
||||
if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
|
||||
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
text_search.setup(
|
||||
@@ -251,7 +256,9 @@ def configure_content(
|
||||
|
||||
try:
|
||||
# Initialize Plaintext Search
|
||||
if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
|
||||
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[
|
||||
"plaintext"
|
||||
]:
|
||||
logger.info("📄 Setting up search for plaintext")
|
||||
# Extract Entries, Generate Plaintext Embeddings
|
||||
text_search.setup(
|
||||
@@ -269,7 +276,7 @@ def configure_content(
|
||||
try:
|
||||
# Initialize Image Search
|
||||
if (
|
||||
(search_type == None or search_type == state.SearchType.Image.value)
|
||||
(search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
|
||||
and content_config
|
||||
and content_config.image
|
||||
and search_models.image_search
|
||||
@@ -286,7 +293,9 @@ def configure_content(
|
||||
|
||||
try:
|
||||
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
|
||||
if (
|
||||
search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
|
||||
) and github_config is not None:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
text_search.setup(
|
||||
@@ -305,7 +314,9 @@ def configure_content(
|
||||
try:
|
||||
# Initialize Notion Search
|
||||
notion_config = NotionConfig.objects.filter(user=user).first()
|
||||
if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
|
||||
if (
|
||||
search_type == state.SearchType.All.value or search_type in state.SearchType.Notion.value
|
||||
) and notion_config:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
text_search.setup(
|
||||
NotionToEntries,
|
||||
|
||||
@@ -229,7 +229,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||
|
||||
# Add the image metadata to the results
|
||||
results += [
|
||||
SearchResponse.parse_obj(
|
||||
SearchResponse.model_validate(
|
||||
{
|
||||
"entry": f"{image_files_url}/{target_image_name}",
|
||||
"score": f"{hit['score']:.9f}",
|
||||
@@ -237,7 +237,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||
"image_score": f"{hit['image_score']:.9f}",
|
||||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||
},
|
||||
"corpus_id": hit["corpus_id"],
|
||||
"corpus_id": str(hit["corpus_id"]),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
@@ -163,7 +163,7 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||
|
||||
else:
|
||||
hit_ids.add(hit.corpus_id)
|
||||
yield SearchResponse.parse_obj(
|
||||
yield SearchResponse.model_validate(
|
||||
{
|
||||
"entry": hit.entry,
|
||||
"score": hit.score,
|
||||
|
||||
@@ -288,15 +288,15 @@ def generate_random_name():
|
||||
# List of adjectives and nouns to choose from
|
||||
adjectives = [
|
||||
"happy",
|
||||
"irritated",
|
||||
"annoyed",
|
||||
"serendipitous",
|
||||
"exuberant",
|
||||
"calm",
|
||||
"brave",
|
||||
"scared",
|
||||
"energetic",
|
||||
"chivalrous",
|
||||
"kind",
|
||||
"grumpy",
|
||||
"suave",
|
||||
]
|
||||
nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"]
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from khoj.utils.helpers import to_snake_case_from_dash
|
||||
class ConfigBase(BaseModel):
|
||||
class Config:
|
||||
alias_generator = to_snake_case_from_dash
|
||||
allow_population_by_field_name = True
|
||||
populate_by_name = True
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
@@ -29,8 +29,8 @@ class TextConfigBase(ConfigBase):
|
||||
|
||||
|
||||
class TextContentConfig(ConfigBase):
|
||||
input_files: Optional[List[Path]]
|
||||
input_filter: Optional[List[str]]
|
||||
input_files: Optional[List[Path]] = None
|
||||
input_filter: Optional[List[str]] = None
|
||||
index_heading_entries: Optional[bool] = False
|
||||
|
||||
|
||||
@@ -50,31 +50,31 @@ class NotionContentConfig(ConfigBase):
|
||||
|
||||
|
||||
class ImageContentConfig(ConfigBase):
|
||||
input_directories: Optional[List[Path]]
|
||||
input_filter: Optional[List[str]]
|
||||
input_directories: Optional[List[Path]] = None
|
||||
input_filter: Optional[List[str]] = None
|
||||
embeddings_file: Path
|
||||
use_xmp_metadata: bool
|
||||
batch_size: int
|
||||
|
||||
|
||||
class ContentConfig(ConfigBase):
|
||||
org: Optional[TextContentConfig]
|
||||
image: Optional[ImageContentConfig]
|
||||
markdown: Optional[TextContentConfig]
|
||||
pdf: Optional[TextContentConfig]
|
||||
plaintext: Optional[TextContentConfig]
|
||||
github: Optional[GithubContentConfig]
|
||||
notion: Optional[NotionContentConfig]
|
||||
org: Optional[TextContentConfig] = None
|
||||
image: Optional[ImageContentConfig] = None
|
||||
markdown: Optional[TextContentConfig] = None
|
||||
pdf: Optional[TextContentConfig] = None
|
||||
plaintext: Optional[TextContentConfig] = None
|
||||
github: Optional[GithubContentConfig] = None
|
||||
notion: Optional[NotionContentConfig] = None
|
||||
|
||||
|
||||
class ImageSearchConfig(ConfigBase):
|
||||
encoder: str
|
||||
encoder_type: Optional[str]
|
||||
model_directory: Optional[Path]
|
||||
encoder_type: Optional[str] = None
|
||||
model_directory: Optional[Path] = None
|
||||
|
||||
|
||||
class SearchConfig(ConfigBase):
|
||||
image: Optional[ImageSearchConfig]
|
||||
image: Optional[ImageSearchConfig] = None
|
||||
|
||||
|
||||
class OpenAIProcessorConfig(ConfigBase):
|
||||
@@ -95,26 +95,26 @@ class ConversationProcessorConfig(ConfigBase):
|
||||
|
||||
|
||||
class ProcessorConfig(ConfigBase):
|
||||
conversation: Optional[ConversationProcessorConfig]
|
||||
conversation: Optional[ConversationProcessorConfig] = None
|
||||
|
||||
|
||||
class AppConfig(ConfigBase):
|
||||
should_log_telemetry: bool
|
||||
should_log_telemetry: bool = True
|
||||
|
||||
|
||||
class FullConfig(ConfigBase):
|
||||
content_type: Optional[ContentConfig] = None
|
||||
search_type: Optional[SearchConfig] = None
|
||||
processor: Optional[ProcessorConfig] = None
|
||||
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
|
||||
app: Optional[AppConfig] = AppConfig()
|
||||
version: Optional[str] = None
|
||||
|
||||
|
||||
class SearchResponse(ConfigBase):
|
||||
entry: str
|
||||
score: float
|
||||
cross_score: Optional[float]
|
||||
additional: Optional[dict]
|
||||
cross_score: Optional[float] = None
|
||||
additional: Optional[dict] = None
|
||||
corpus_id: str
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def load_config_from_file(yaml_config_file: Path) -> dict:
|
||||
|
||||
def parse_config_from_string(yaml_config: dict) -> FullConfig:
|
||||
"Parse and validate config in YML string"
|
||||
return FullConfig.parse_obj(yaml_config)
|
||||
return FullConfig.model_validate(yaml_config)
|
||||
|
||||
|
||||
def parse_config_from_file(yaml_config_file):
|
||||
|
||||
Reference in New Issue
Block a user