diff --git a/pyproject.toml b/pyproject.toml index 41e8f2d9..ec31a5f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dependencies = [ "itsdangerous == 2.1.2", "httpx == 0.28.1", "pgvector == 0.2.4", + "pgserver == 0.1.4", "psycopg2-binary == 2.9.9", "lxml == 4.9.3", "tzdata == 2023.3", diff --git a/src/khoj/app/settings.py b/src/khoj/app/settings.py index 48879f60..c90f01eb 100644 --- a/src/khoj/app/settings.py +++ b/src/khoj/app/settings.py @@ -10,6 +10,8 @@ For the full list of settings and their values, see https://docs.djangoproject.com/en/4.2/ref/settings/ """ +import atexit +import logging import os from pathlib import Path @@ -119,13 +121,71 @@ CLOSE_CONNECTIONS_AFTER_REQUEST = True # Database # https://docs.djangoproject.com/en/4.2/ref/settings/#databases DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000 + +# Default PostgreSQL configuration +DB_NAME = os.getenv("POSTGRES_DB", "khoj") +DB_HOST = os.getenv("POSTGRES_HOST", "localhost") +DB_PORT = os.getenv("POSTGRES_PORT", "5432") + +# Use pgserver if env var explicitly set to true +USE_EMBEDDED_DB = is_env_var_true("USE_EMBEDDED_DB") + +if USE_EMBEDDED_DB: + # Set up logging for pgserver + logger = logging.getLogger("pgserver_django") + logger.setLevel(logging.INFO) + if not logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) + logger.addHandler(handler) + + try: + import pgserver + + # Set up data directory + PGSERVER_DATA_DIR = os.path.join(BASE_DIR, "pgserver_data") + os.makedirs(PGSERVER_DATA_DIR, exist_ok=True) + + logger.debug(f"Initializing embedded Postgres DB with data directory: {PGSERVER_DATA_DIR}") + + # Start server + PGSERVER_INSTANCE = pgserver.get_server(PGSERVER_DATA_DIR) + + # Create pgvector extension, if not already exists + PGSERVER_INSTANCE.psql("CREATE EXTENSION IF NOT EXISTS vector;") + + # Create database, if not already exists + db_exists_result = PGSERVER_INSTANCE.psql(f"SELECT 1 FROM pg_database WHERE datname = '{DB_NAME}';") + db_exists = "(1 row)" in db_exists_result # Check for actual row in result + if not db_exists: + logger.info(f"Creating database: {DB_NAME}") + PGSERVER_INSTANCE.psql(f"CREATE DATABASE {DB_NAME};") + + # Register cleanup + def cleanup_pgserver(): + if PGSERVER_INSTANCE: + logger.debug("Shutting down embedded Postgres DB") + PGSERVER_INSTANCE.cleanup() + + atexit.register(cleanup_pgserver) + + # Update database configuration for pgserver + DB_HOST = PGSERVER_DATA_DIR + DB_PORT = "" # pgserver uses Unix socket, so port is empty + + logger.info("Embedded Postgres DB started successfully") + + except Exception as e: + logger.error(f"Error initializing embedded Postgres DB: {str(e)}. Use standard PostgreSQL server.") + +# Set the database configuration DATABASES = { "default": { "ENGINE": "django.db.backends.postgresql", - "HOST": os.getenv("POSTGRES_HOST", "localhost"), - "PORT": os.getenv("POSTGRES_PORT", "5432"), + "HOST": DB_HOST, + "PORT": DB_PORT, "USER": os.getenv("POSTGRES_USER", "postgres"), - "NAME": os.getenv("POSTGRES_DB", "khoj"), + "NAME": DB_NAME, "PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"), "CONN_MAX_AGE": 0, "CONN_HEALTH_CHECKS": True,