From 62704cac0996f14549c69168daa307d9a47437e0 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Sun, 9 Jul 2023 15:29:26 -0700 Subject: [PATCH] Add a plugin which allows users to index their Notion pages (#284) * For the demo instance, re-instate the scheduler, but infrequently for api updates - In constants, determine the cadence based on whether it's a demo instance or not - This allows us to collect telemetry again. This will also allow us to save the chat session * Conditionally skip updating the index altogether if it's a demo instance * Add backend support for Notion data parsing - Add a NotionToJsonl class which parses the text of Notion documents made accessible to the API token - Make corresponding updates to the default config, raw config to support the new notion addition * Add corresponding views to support configuring Notion from the web-based settings page - Support backend APIs for deleting/configuring notion setup as well - Streamline some of the index updating code * Use defaults for search and chat queries results count * Update pagination of retrieving pages from Notion * Update state conversation processor when update is hit * frequency_penalty should be passed to gpt through kwargs * Add check for notion in render_multiple method * Add headings to Notion render * Revert results count slider and split Notion files by blocks * Clean/fix misc things in the function to update index - Use the successText and errorText variables appropriately - Name parameters in function calls - Add emojis, woohoo * Clean up and further modularize code for processing data in Notion --- src/khoj/configure.py | 15 +- .../interface/web/assets/icons/notion.svg | 4 + src/khoj/interface/web/chat.html | 1 + src/khoj/interface/web/config.html | 97 ++++--- .../web/content_type_notion_input.html | 86 +++++++ src/khoj/interface/web/index.html | 8 +- src/khoj/processor/conversation/gpt.py | 3 +-
src/khoj/processor/notion/notion_to_jsonl.py | 243 ++++++++++++++++++ src/khoj/processor/text_to_jsonl.py | 7 +- src/khoj/routers/api.py | 50 +++- src/khoj/routers/web_client.py | 22 ++ src/khoj/search_type/text_search.py | 13 +- src/khoj/utils/config.py | 18 +- src/khoj/utils/constants.py | 5 + src/khoj/utils/rawconfig.py | 12 +- 15 files changed, 520 insertions(+), 64 deletions(-) create mode 100644 src/khoj/interface/web/assets/icons/notion.svg create mode 100644 src/khoj/interface/web/content_type_notion_input.html create mode 100644 src/khoj/processor/notion/notion_to_jsonl.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 036e798b..3b5426ae 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -17,6 +17,7 @@ from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl +from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.search_type import image_search, text_search from khoj.utils import constants, state from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel @@ -169,6 +170,18 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) + + # Initialize Notion Search + if (t == None or t in state.SearchType) and config.content_type.notion: + logger.info("🔌 Setting up search for notion") + model.notion_search = text_search.setup( + NotionToJsonl, + config.content_type.notion, + search_config=config.search_type.asymmetric, + regenerate=regenerate, + filters=[DateFilter(), WordFilter(), FileFilter()], + ) + except Exception as e: logger.error("🚨 Failed to setup search") raise e @@ -248,7 +261,7 @@ def save_chat_session(): @schedule.repeat(schedule.every(59).minutes) def 
upload_telemetry(): - if not state.config or not state.config.app.should_log_telemetry or not state.telemetry: + if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry: message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled" logger.debug(message) return diff --git a/src/khoj/interface/web/assets/icons/notion.svg b/src/khoj/interface/web/assets/icons/notion.svg new file mode 100644 index 00000000..bf6442f7 --- /dev/null +++ b/src/khoj/interface/web/assets/icons/notion.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index d8872b9b..7914d3a3 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -93,6 +93,7 @@ // Decode message chunk from stream const chunk = decoder.decode(value, { stream: true }); + if (chunk.includes("### compiled references:")) { const additionalResponse = chunk.split("### compiled references:")[0]; new_response_text.innerHTML += additionalResponse; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 69a2c4f4..665cc3ff 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -14,7 +14,6 @@ Configured {% endif %} -

Set repositories for Khoj to index

@@ -37,6 +36,37 @@
{% endif %} +
+
+ Notion +

+ Notion + {% if current_config.content_type.notion %} + Configured + {% endif %} +

+
+
+

Configure your settings from Notion

+
+
+ + {% if current_config.content_type.content %} + Update + {% else %} + Setup + {% endif %} + + +
+ {% if current_config.content_type.notion %} +
+ +
+ {% endif %} +
markdown @@ -224,40 +254,32 @@ var configure = document.getElementById("configure"); configure.addEventListener("click", function(event) { event.preventDefault(); - configure.disabled = true; - configure.innerHTML = "Configuring..."; - const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/update?&client=web', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - 'X-CSRFToken': csrfToken - } - }) - .then(response => response.json()) - .then(data => { - console.log('Success:', data); - document.getElementById("status").innerHTML = "Configured successfully!"; - document.getElementById("status").style.display = "block"; - configure.disabled = false; - configure.innerHTML = "⚙️ Configured"; - }) - .catch((error) => { - console.error('Error:', error); - document.getElementById("status").innerHTML = "Unable to save configuration. Raise issue on Khoj Discord or Github."; - document.getElementById("status").style.display = "block"; - configure.disabled = false; - configure.innerHTML = "⚙️ Configure"; - }); + updateIndex( + force=false, + successText="Configured successfully!", + errorText="Unable to configure. Raise issue on Khoj Github or Discord.", + button=configure, + loadingText="Configuring...", + emoji="⚙️"); }); var reinitialize = document.getElementById("reinitialize"); reinitialize.addEventListener("click", function(event) { event.preventDefault(); - reinitialize.disabled = true; - reinitialize.innerHTML = "Reinitializing..."; + updateIndex( + force=true, + successText="Reinitialized successfully!", + errorText="Unable to reinitialize. 
Raise issue on Khoj Github or Discord.", + button=reinitialize, + loadingText="Reinitializing...", + emoji="🔄"); + }); + + function updateIndex(force, successText, errorText, button, loadingText, emoji) { const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/update?&client=web&force=True', { + button.disabled = true; + button.innerHTML = emoji + loadingText; + fetch('/api/update?&client=web&force=' + force, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -267,19 +289,22 @@ .then(response => response.json()) .then(data => { console.log('Success:', data); - document.getElementById("status").innerHTML = "Reinitialized successfully!"; + if (data.detail != null) { + throw new Error(data.detail); + } + document.getElementById("status").innerHTML = emoji + successText; document.getElementById("status").style.display = "block"; - reinitialize.disabled = false; - reinitialize.innerHTML = "🔄 Reinitialized"; + button.disabled = false; + button.innerHTML = '✅ Done!'; }) .catch((error) => { console.error('Error:', error); - document.getElementById("status").innerHTML = "Unable to reinitialize. Raise issue on Khoj Discord or Github."; + document.getElementById("status").innerHTML = emoji + errorText document.getElementById("status").style.display = "block"; - reinitialize.disabled = false; - reinitialize.innerHTML = "🔄 Reinitialize"; + button.disabled = false; + button.innerHTML = '⚠️ Unsuccessful'; }); - }); + } // Setup the results count slider const resultsCountSlider = document.getElementById('results-count-slider'); diff --git a/src/khoj/interface/web/content_type_notion_input.html b/src/khoj/interface/web/content_type_notion_input.html new file mode 100644 index 00000000..494ff7a3 --- /dev/null +++ b/src/khoj/interface/web/content_type_notion_input.html @@ -0,0 +1,86 @@ +{% extends "base_config.html" %} +{% block content %} +
+
+

+ Notion + Notion +

+
+ + + + + +
+ + + +
+ + + + + + + + + +
+ + + +
+ + + +
+
+ + +
+
+
+
+ +{% endblock %} diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html index 000bc0e2..8702d0c6 100644 --- a/src/khoj/interface/web/index.html +++ b/src/khoj/interface/web/index.html @@ -71,6 +71,8 @@ html += render_markdown(query, [item]); } else if (item.additional.file.endsWith(".pdf")) { html += render_pdf(query, [item]); + } else if (item.additional.file.includes("notion.so")) { + html += `
` + `${item.additional.heading}` + `

${item.entry}

` + `
`; } }); return html; @@ -86,7 +88,7 @@ results = data.map(render_image).join(''); } else if (type === "pdf") { results = render_pdf(query, data); - } else if (type === "github" || type === "all") { + } else if (type === "github" || type === "all" || type === "notion") { results = render_multiple(query, data, type); } else { results = data.map((item) => `
` + `

${item.entry}

` + `
`).join("\n") @@ -127,7 +129,7 @@ setQueryFieldInUrl(query); // Execute Search and Render Results - url = createRequestUrl(query, type, results_count, rerank); + url = createRequestUrl(query, type, results_count || 5, rerank); fetch(url) .then(response => response.json()) .then(data => { @@ -347,6 +349,7 @@ white-space: pre-wrap; } .results-pdf, + .results-notion, .results-plugin { text-align: left; white-space: pre-line; @@ -404,6 +407,7 @@ div#results-error, div.results-markdown, + div.results-notion, div.results-org, div.results-pdf { text-align: left; diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 226af3fb..e053be15 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -32,8 +32,7 @@ def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200): model_name=model, temperature=temperature, max_tokens=max_tokens, - frequency_penalty=0.2, - model_kwargs={"stop": ['"""']}, + model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2}, openai_api_key=api_key, ) diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py new file mode 100644 index 00000000..f7212d79 --- /dev/null +++ b/src/khoj/processor/notion/notion_to_jsonl.py @@ -0,0 +1,243 @@ +# Standard Packages +import logging + +# External Packages +import requests + +# Internal Packages +from khoj.utils.helpers import timer +from khoj.utils.rawconfig import Entry, NotionContentConfig +from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.rawconfig import Entry + +from enum import Enum + + +logger = logging.getLogger(__name__) + + +class NotionBlockType(Enum): + PARAGRAPH = "paragraph" + HEADING_1 = "heading_1" + HEADING_2 = "heading_2" + HEADING_3 = "heading_3" + BULLETED_LIST_ITEM = "bulleted_list_item" + NUMBERED_LIST_ITEM = "numbered_list_item" + TO_DO = "to_do" + TOGGLE = 
"toggle" + CHILD_PAGE = "child_page" + UNSUPPORTED = "unsupported" + BOOKMARK = "bookmark" + DIVIDER = "divider" + PDF = "pdf" + IMAGE = "image" + EMBED = "embed" + VIDEO = "video" + FILE = "file" + SYNCED_BLOCK = "synced_block" + TABLE_OF_CONTENTS = "table_of_contents" + COLUMN = "column" + EQUATION = "equation" + LINK_PREVIEW = "link_preview" + COLUMN_LIST = "column_list" + QUOTE = "quote" + BREADCRUMB = "breadcrumb" + LINK_TO_PAGE = "link_to_page" + CHILD_DATABASE = "child_database" + TEMPLATE = "template" + CALLOUT = "callout" + + +class NotionToJsonl(TextToJsonl): + def __init__(self, config: NotionContentConfig): + super().__init__(config) + self.config = config + self.session = requests.Session() + self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"}) + self.unsupported_block_types = [ + NotionBlockType.BOOKMARK.value, + NotionBlockType.DIVIDER.value, + NotionBlockType.CHILD_DATABASE.value, + NotionBlockType.TEMPLATE.value, + NotionBlockType.CALLOUT.value, + NotionBlockType.UNSUPPORTED.value, + ] + + self.display_block_block_types = [ + NotionBlockType.PARAGRAPH.value, + NotionBlockType.HEADING_1.value, + NotionBlockType.HEADING_2.value, + NotionBlockType.HEADING_3.value, + NotionBlockType.BULLETED_LIST_ITEM.value, + NotionBlockType.NUMBERED_LIST_ITEM.value, + NotionBlockType.TO_DO.value, + NotionBlockType.TOGGLE.value, + NotionBlockType.CHILD_PAGE.value, + NotionBlockType.BOOKMARK.value, + NotionBlockType.DIVIDER.value, + ] + + def process(self, previous_entries=None): + current_entries = [] + + # Get all pages + with timer("Getting all pages via search endpoint", logger=logger): + responses = [] + + while True: + result = self.session.post( + "https://api.notion.com/v1/search", + json={"page_size": 100}, + ).json() + responses.append(result) + if result["has_more"] == False: + break + else: + self.session.params = {"start_cursor": responses[-1]["next_cursor"]} + + for response in responses: + with 
timer("Processing response", logger=logger): + pages_or_databases = response["results"] + + # Get all pages content + for p_or_d in pages_or_databases: + with timer(f"Processing {p_or_d['object']} {p_or_d['id']}", logger=logger): + if p_or_d["object"] == "database": + # TODO: Handle databases + continue + elif p_or_d["object"] == "page": + page_entries = self.process_page(p_or_d) + current_entries.extend(page_entries) + + return self.update_entries_with_ids(current_entries, previous_entries) + + def process_page(self, page): + page_id = page["id"] + title, content = self.get_page_content(page_id) + + if title == None or content == None: + return [] + + current_entries = [] + curr_heading = "" + for block in content["results"]: + block_type = block.get("type") + + if block_type == None: + continue + block_data = block[block_type] + + if block_data.get("rich_text") == None or len(block_data["rich_text"]) == 0: + # There's no text to handle here. + continue + + raw_content = "" + if block_type in ["heading_1", "heading_2", "heading_3"]: + # If the current block is a heading, we can consider the previous block processing completed. + # Add it as an entry and move on to processing the next chunk of the page. 
+ if raw_content != "": + current_entries.append( + Entry( + compiled=raw_content, + raw=raw_content, + heading=title, + file=page["url"], + ) + ) + curr_heading = block_data["rich_text"][0]["plain_text"] + else: + if curr_heading != "": + # Add the last known heading to the content for additional context + raw_content = self.process_heading(curr_heading) + for text in block_data["rich_text"]: + raw_content += self.process_text(text) + + if block.get("has_children", True): + raw_content += "\n" + raw_content = self.process_nested_children( + self.get_block_children(block["id"]), raw_content, block_type + ) + + if raw_content != "": + current_entries.append( + Entry( + compiled=raw_content, + raw=raw_content, + heading=title, + file=page["url"], + ) + ) + return current_entries + + def process_heading(self, heading): + return f"\n{heading}\n" + + def process_nested_children(self, children, raw_content, block_type=None): + for child in children["results"]: + child_type = child.get("type") + if child_type == None: + continue + child_data = child[child_type] + if child_data.get("rich_text") and len(child_data["rich_text"]) > 0: + for text in child_data["rich_text"]: + raw_content += self.process_text(text, block_type) + if child_data.get("has_children", True): + return self.process_nested_children(self.get_block_children(child["id"]), raw_content, block_type) + + return raw_content + + def process_text(self, text, block_type=None): + text_type = text.get("type", None) + if text_type in self.unsupported_block_types: + return "" + if text.get("href", None): + return f"{text['plain_text']}" + raw_text = text["plain_text"] + if text_type in self.display_block_block_types or block_type in self.display_block_block_types: + return f"\n{raw_text}\n" + return raw_text + + def get_block_children(self, block_id): + return self.session.get(f"https://api.notion.com/v1/blocks/{block_id}/children").json() + + def get_page(self, page_id): + return 
self.session.get(f"https://api.notion.com/v1/pages/{page_id}").json() + + def get_page_children(self, page_id): + return self.session.get(f"https://api.notion.com/v1/blocks/{page_id}/children").json() + + def get_page_content(self, page_id): + try: + page = self.get_page(page_id) + content = self.get_page_children(page_id) + except Exception as e: + logger.error(f"Error getting page {page_id}: {e}") + return None, None + properties = page["properties"] + title_field = "Title" if "Title" in properties else "title" + title = page["properties"][title_field]["title"][0]["text"]["content"] + return title, content + + def update_entries_with_ids(self, current_entries, previous_entries): + # Identify, mark and merge any new entries with previous entries + with timer("Identify new or updated entries", logger): + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) + + with timer("Write Notion entries to JSONL file", logger): + # Process Each Entry from all Notion entries + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries) + + # Compress JSONL formatted Data + if self.config.compressed_jsonl.suffix == ".gz": + compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) + elif self.config.compressed_jsonl.suffix == ".jsonl": + dump_jsonl(jsonl_data, self.config.compressed_jsonl) + + return entries_with_ids diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index f7bca376..a4d01cf5 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -62,7 +62,7 @@ class TextToJsonl(ABC): @staticmethod def mark_entries_for_update( - current_entries: List[Entry], previous_entries: List[Entry], key="compiled", logger=None + current_entries: List[Entry], previous_entries: List[Entry], 
key="compiled", logger: logging.Logger = None ) -> List[Tuple[int, Entry]]: # Hash all current and previous entries to identify new entries with timer("Hash previous, current entries", logger): @@ -90,3 +90,8 @@ class TextToJsonl(ABC): entries_with_ids = existing_entries_sorted + new_entries return entries_with_ids + + @staticmethod + def convert_text_maps_to_jsonl(entries: List[Entry]) -> str: + # Convert each entry to JSON and write to JSONL file + return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 39593b44..8d5d67b2 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -28,6 +28,7 @@ from khoj.utils.rawconfig import ( TextContentConfig, ConversationProcessorConfig, GithubContentConfig, + NotionContentConfig, ) from khoj.utils.state import SearchType from khoj.utils import state, constants @@ -45,6 +46,11 @@ logger = logging.getLogger(__name__) # If it's a demo instance, prevent updating any of the configuration. 
if not state.demo: + def _initialize_config(): + if state.config is None: + state.config = FullConfig() + state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) + @api.get("/config/data", response_model=FullConfig) def get_config_data(): return state.config @@ -59,9 +65,7 @@ if not state.demo: @api.post("/config/data/content_type/github", status_code=200) async def set_content_config_github_data(updated_config: Union[GithubContentConfig, None]): - if not state.config: - state.config = FullConfig() - state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) + _initialize_config() if not state.config.content_type: state.config.content_type = ContentConfig(**{"github": updated_config}) @@ -74,6 +78,21 @@ if not state.demo: except Exception as e: return {"status": "error", "message": str(e)} + @api.post("/config/data/content_type/notion", status_code=200) + async def set_content_config_notion_data(updated_config: Union[NotionContentConfig, None]): + _initialize_config() + + if not state.config.content_type: + state.config.content_type = ContentConfig(**{"notion": updated_config}) + else: + state.config.content_type.notion = updated_config + + try: + save_config_to_file_updated_state() + return {"status": "ok"} + except Exception as e: + return {"status": "error", "message": str(e)} + @api.post("/delete/config/data/content_type/{content_type}", status_code=200) async def remove_content_config_data(content_type: str): if not state.config or not state.config.content_type: @@ -84,6 +103,8 @@ if not state.demo: if content_type == "github": state.model.github_search = None + elif content_type == "notion": + state.model.notion_search = None elif content_type == "plugins": state.model.plugin_search = None elif content_type == "pdf": @@ -114,9 +135,7 @@ if not state.demo: @api.post("/config/data/content_type/{content_type}", status_code=200) async def set_content_config_data(content_type: str, updated_config: 
Union[TextContentConfig, None]): - if not state.config: - state.config = FullConfig() - state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) + _initialize_config() if not state.config.content_type: state.config.content_type = ContentConfig(**{content_type: updated_config}) @@ -131,9 +150,8 @@ if not state.demo: @api.post("/config/data/processor/conversation", status_code=200) async def set_processor_conversation_config_data(updated_config: Union[ConversationProcessorConfig, None]): - if not state.config: - state.config = FullConfig() - state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) + _initialize_config() + state.config.processor = ProcessorConfig(conversation=updated_config) state.processor_config = configure_processor(state.config.processor) try: @@ -312,6 +330,20 @@ async def search( ) ] + if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search: + # query notion pages + search_futures += [ + executor.submit( + text_search.query, + user_query, + state.model.notion_search, + question_embedding=encoded_asymmetric_query, + rank_results=r or False, + score_threshold=score_threshold, + dedupe=dedupe or True, + ) + ] + # Query across each requested content types in parallel with timer("Query took", logger): for search_future in concurrent.futures.as_completed(search_futures): diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index afdd31ec..e3bb7dde 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -63,6 +63,28 @@ if not state.demo: "content_type_github_input.html", context={"request": request, "current_config": current_config} ) + @web_client.get("/config/content_type/notion", response_class=HTMLResponse) + def notion_config_page(request: Request): + default_copy = constants.default_config.copy() + default_notion = default_copy["content-type"]["notion"] # type: ignore + + default_config = 
TextContentConfig( + compressed_jsonl=default_notion["compressed-jsonl"], + embeddings_file=default_notion["embeddings-file"], + ) + + current_config = ( + state.config.content_type.notion + if state.config and state.config.content_type and state.config.content_type.notion + else default_config + ) + + current_config = json.loads(current_config.json()) + + return templates.TemplateResponse( + "content_type_notion_input.html", context={"request": request, "current_config": current_config} + ) + @web_client.get("/config/content_type/{content_type}", response_class=HTMLResponse) def content_config_page(request: Request, content_type: str): if content_type not in VALID_TEXT_CONTENT_TYPES: diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 0af5b0fc..09057f9a 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -15,7 +15,7 @@ from khoj.utils import state from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer from khoj.utils.config import TextSearchModel from khoj.utils.models import BaseEncoder -from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry +from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry from khoj.utils.jsonl import load_jsonl @@ -159,7 +159,11 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse] { "entry": entries[hit["corpus_id"]].raw, "score": f"{hit.get('cross-score') or hit.get('score')}", - "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled}, + "additional": { + "file": entries[hit["corpus_id"]].file, + "compiled": entries[hit["corpus_id"]].compiled, + "heading": entries[hit["corpus_id"]].heading, + }, } ) for hit in hits[0:count] @@ -168,7 +172,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse] def setup( text_to_jsonl: 
Type[TextToJsonl], - config: TextContentConfig, + config: TextConfigBase, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = [], @@ -186,7 +190,8 @@ def setup( # Extract Updated Entries entries = extract_entries(config.compressed_jsonl) if is_none_or_empty(entries): - raise ValueError(f"No valid entries found in specified files: {config.input_files} or {config.input_filter}") + config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()]) + raise ValueError(f"No valid entries found in specified files: {config_params}") top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 155cdcc6..7887e9cd 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints from enum import Enum from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Union # External Packages import torch @@ -23,6 +23,7 @@ class SearchType(str, Enum): Image = "image" Pdf = "pdf" Github = "github" + Notion = "notion" class ProcessorType(str, Enum): @@ -58,12 +59,13 @@ class ImageSearchModel: @dataclass class SearchModels: - org_search: TextSearchModel = None - markdown_search: TextSearchModel = None - pdf_search: TextSearchModel = None - image_search: ImageSearchModel = None - github_search: TextSearchModel = None - plugin_search: Dict[str, TextSearchModel] = None + org_search: Union[TextSearchModel, None] = None + markdown_search: Union[TextSearchModel, None] = None + pdf_search: Union[TextSearchModel, None] = None + image_search: Union[ImageSearchModel, None] = None + github_search: Union[TextSearchModel, None] = None + notion_search: Union[TextSearchModel, None] = None + plugin_search: Union[Dict[str, TextSearchModel], None] = None 
class ConversationProcessorConfigModel: @@ -78,4 +80,4 @@ class ConversationProcessorConfigModel: @dataclass class ProcessorConfigModel: - conversation: ConversationProcessorConfigModel = None + conversation: Union[ConversationProcessorConfigModel, None] = None diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index caf64ac2..f1de7d76 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -41,6 +41,11 @@ default_config = { "compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz", "embeddings-file": "~/.khoj/content/github/github_embeddings.pt", }, + "notion": { + "token": None, + "compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz", + "embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt", + }, }, "search-type": { "symmetric": { diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index b13c7449..0172dc1f 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -52,6 +52,10 @@ class GithubContentConfig(TextConfigBase): repos: List[GithubRepoConfig] +class NotionContentConfig(TextConfigBase): + token: str + + class ImageContentConfig(ConfigBase): input_directories: Optional[List[Path]] input_filter: Optional[List[str]] @@ -77,6 +81,7 @@ class ContentConfig(ConfigBase): pdf: Optional[TextContentConfig] github: Optional[GithubContentConfig] plugins: Optional[Dict[str, TextContentConfig]] + notion: Optional[NotionContentConfig] class TextSearchConfig(ConfigBase): @@ -148,4 +153,9 @@ class Entry: @classmethod def from_dict(cls, dictionary: dict): - return cls(raw=dictionary["raw"], compiled=dictionary["compiled"], file=dictionary.get("file", None)) + return cls( + raw=dictionary["raw"], + compiled=dictionary["compiled"], + file=dictionary.get("file", None), + heading=dictionary.get("heading", None), + )