diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 4f71c7aa..61bb82ce 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -134,10 +134,11 @@ async def set_user_subscription( return None -def get_user_subscription_state(user_subscription: Subscription) -> str: +def get_user_subscription_state(email: str) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ + user_subscription = Subscription.objects.filter(user__email=email).first() if not user_subscription: return "trial" elif user_subscription.type == Subscription.Type.TRIAL: @@ -450,5 +451,5 @@ class EntryAdapters: return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() @staticmethod - def get_unique_file_source(user: KhojUser): - return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct() + def get_unique_file_sources(user: KhojUser): + return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 17d623ff..5f273cc3 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -37,8 +37,7 @@ templates = Jinja2Templates(directory=constants.web_directory) def index(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) return templates.TemplateResponse( "chat.html", @@ -56,8 +55,7 @@ def index(request: Request): def index_post(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) return templates.TemplateResponse( "chat.html", @@ -75,8 +73,7 @@ def index_post(request: Request): def search_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) return templates.TemplateResponse( "search.html", @@ -94,8 +91,7 @@ def search_page(request: Request): def chat_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) return templates.TemplateResponse( "chat.html", @@ -130,28 +126,27 @@ def login_page(request: Request): def config_page(request: Request): user: KhojUser = request.user.object user_picture = request.session.get("user", {}).get("picture") + user_subscription_state = get_user_subscription_state(user.email) user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) subscription_renewal_date = ( user_subscription.renewal_date.strftime("%d %b %Y") if user_subscription and user_subscription.renewal_date else None ) - enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all()) + enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) successfully_configured = { "computer": ("computer" in enabled_content_source), "github": ("github" in enabled_content_source), "notion": ("notion" in enabled_content_source), } + selected_conversation_config = ConversationAdapters.get_conversation_config(user) conversation_options = ConversationAdapters.get_conversation_processor_options().all() all_conversation_options = list() for conversation_option in conversation_options: all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id}) - selected_conversation_config = ConversationAdapters.get_conversation_config(user) - return templates.TemplateResponse( "config.html", context={ @@ -176,8 +171,7 @@ def config_page(request: Request): def github_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) current_github_config = get_user_github_config(user) if current_github_config: @@ -216,8 +210,7 @@ def github_config_page(request: Request): def notion_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = adapters.get_user_subscription(user.email) current_notion_config = get_user_notion_config(user) current_config = NotionContentConfig( @@ -243,8 +236,7 @@ def notion_config_page(request: Request): def computer_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription = adapters.get_user_subscription(user.email) - user_subscription_state = get_user_subscription_state(user_subscription) + user_subscription_state = get_user_subscription_state(user.email) return templates.TemplateResponse( "content_source_computer_input.html",