Add support for a first-party client app to call into Khoj (Part 1) (#601)

* Add support for a first-party client app
- Based on a client id and client secret, allow a first-party app to call into the Khoj backend with a phone number identifier
- Add migration to add phone numbers to the KhojUser object
* Add plus in front of country code when registering a phone number.
- Decrease free tier limit to 5 (from 10)
- Return a response object when handling stripe webhooks
* Fix telemetry method which references authenticated user's client app
* Add better error handling for null phone numbers, simplify logic of authenticating user
* Pull the client_secret in the API call from the authorization header
* Add a merge migration to reconcile the phone number migration with other migration changes
This commit is contained in:
sabaimran
2024-01-18 05:54:14 -08:00
committed by GitHub
parent 9dfe1bb003
commit 039ed78253
11 changed files with 204 additions and 18 deletions

View File

@@ -76,6 +76,8 @@ dependencies = [
"rapidocr-onnxruntime == 1.3.8", "rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0", "stripe == 7.3.0",
"openai-whisper >= 20231117", "openai-whisper >= 20231117",
"django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27",
] ]
dynamic = ["version"] dynamic = ["version"]

View File

@@ -62,6 +62,7 @@ INSTALLED_APPS = [
"django.contrib.sessions", "django.contrib.sessions",
"django.contrib.messages", "django.contrib.messages",
"django.contrib.staticfiles", "django.contrib.staticfiles",
"phonenumber_field",
] ]
MIDDLEWARE = [ MIDDLEWARE = [

View File

@@ -9,6 +9,7 @@ import openai
import requests import requests
import schedule import schedule
from django.utils.timezone import make_aware from django.utils.timezone import make_aware
from fastapi import Response
from starlette.authentication import ( from starlette.authentication import (
AuthCredentials, AuthCredentials,
AuthenticationBackend, AuthenticationBackend,
@@ -20,27 +21,32 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from khoj.database.adapters import ( from khoj.database.adapters import (
ClientApplicationAdapters,
ConversationAdapters, ConversationAdapters,
SubscriptionState, SubscriptionState,
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state, aget_user_subscription_state,
get_all_users, get_all_users,
get_or_create_search_models, get_or_create_search_models,
) )
from khoj.database.models import KhojUser, Subscription from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search, load_content from khoj.routers.indexer import configure_content, configure_search, load_content
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import SearchType from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AuthenticatedKhojUser(SimpleUser): class AuthenticatedKhojUser(SimpleUser):
def __init__(self, user): def __init__(self, user, client_app: Optional[ClientApplication] = None):
self.object = user self.object = user
super().__init__(user.email) self.client_app = client_app
super().__init__(user.username)
class UserAuthenticationBackend(AuthenticationBackend): class UserAuthenticationBackend(AuthenticationBackend):
@@ -108,6 +114,53 @@ class UserAuthenticationBackend(AuthenticationBackend):
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
# Get query params for client_id and client_secret
client_id = request.query_params.get("client_id")
if client_id:
# Get the client secret, which is passed in the Authorization header
client_secret = request.headers["Authorization"].split("Bearer ")[1]
if not client_secret:
return Response(
status_code=401,
content="Please provide a client secret in the Authorization header with a client_id query param.",
)
# Get the client application
client_application = await ClientApplicationAdapters.aget_client_application_by_id(client_id, client_secret)
if client_application is None:
return AuthCredentials(), UnauthenticatedUser()
# Get the identifier used for the user
phone_number = request.query_params.get("phone_number")
if is_none_or_empty(phone_number):
return AuthCredentials(), UnauthenticatedUser()
if not phone_number.startswith("+"):
phone_number = f"+{phone_number}"
create_if_not_exists = request.query_params.get("create_if_not_exists")
if create_if_not_exists:
user = await aget_or_create_user_by_phone_number(phone_number)
else:
user = await aget_user_by_phone_number(phone_number)
if user is None:
return AuthCredentials(), UnauthenticatedUser()
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return (
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user),
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:

View File

@@ -17,6 +17,7 @@ from torch import Tensor
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
ClientApplication,
Conversation, Conversation,
Entry, Entry,
GithubConfig, GithubConfig,
@@ -40,7 +41,7 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name from khoj.utils.helpers import generate_random_name, is_none_or_empty
class SubscriptionState(Enum): class SubscriptionState(Enum):
@@ -85,6 +86,28 @@ async def get_or_create_user(token: dict) -> KhojUser:
return user return user
async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
    """Return the user registered under ``phone_number``, creating one if none exists.

    Returns ``None`` without querying for a user when
    ``is_none_or_empty(phone_number)`` is true.
    """
    if is_none_or_empty(phone_number):
        return None

    existing_user = await aget_user_by_phone_number(phone_number)
    if existing_user:
        return existing_user

    # No match on this number: register a new account for it.
    return await acreate_user_by_phone_number(phone_number)
async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
    """Create (or update) the user keyed by ``phone_number`` and ensure it has a subscription.

    The phone number is also stored as the username. Returns ``None`` when
    ``is_none_or_empty(phone_number)`` is true.

    A trial subscription is created only when the user does not have one yet:
    ``aupdate_or_create`` may match an already-registered number, and
    unconditionally creating another subscription for that user would add a
    duplicate row (or fail, if the relation allows only one per user).
    """
    if is_none_or_empty(phone_number):
        return None

    user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
        defaults={"username": phone_number, "phone_number": phone_number}
    )
    await user.asave()

    # Keep repeat calls for the same number from stacking up subscriptions.
    has_subscription = await Subscription.objects.filter(user=user).aexists()
    if not has_subscription:
        await Subscription.objects.acreate(user=user, type="trial")
    return user
async def get_or_create_user_by_email(email: str) -> KhojUser: async def get_or_create_user_by_email(email: str) -> KhojUser:
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email}) user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
await user.asave() await user.asave()
@@ -187,6 +210,12 @@ async def get_user_by_token(token: dict) -> KhojUser:
return google_user.user return google_user.user
async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
    """Look up the user registered under ``phone_number``.

    Returns the first match with its ``subscription`` relation prefetched, or
    ``None`` when nothing matches or ``is_none_or_empty(phone_number)`` is true.
    """
    if is_none_or_empty(phone_number):
        return None

    matching_users = KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription")
    return await matching_users.afirst()
async def retrieve_user(session_id: str) -> KhojUser: async def retrieve_user(session_id: str) -> KhojUser:
session = SessionStore(session_key=session_id) session = SessionStore(session_key=session_id)
if not await sync_to_async(session.exists)(session_key=session_id): if not await sync_to_async(session.exists)(session_key=session_id):
@@ -270,6 +299,12 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
return new_config return new_config
class ClientApplicationAdapters:
    """Database lookups for first-party client applications."""

    @staticmethod
    async def aget_client_application_by_id(client_id: str, client_secret: str):
        """Return the first ClientApplication matching both ``client_id`` and ``client_secret``.

        Resolves to ``None`` when no row matches the pair.
        """
        credentials = {"client_id": client_id, "client_secret": client_secret}
        return await ClientApplication.objects.filter(**credentials).afirst()
class ConversationAdapters: class ConversationAdapters:
@staticmethod @staticmethod
def get_conversation_by_user(user: KhojUser): def get_conversation_by_user(user: KhojUser):
@@ -279,11 +314,11 @@ class ConversationAdapters:
return Conversation.objects.create(user=user) return Conversation.objects.create(user=user)
@staticmethod @staticmethod
async def aget_conversation_by_user(user: KhojUser): async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user) conversation = Conversation.objects.filter(user=user, client=client_application)
if await conversation.aexists(): if await conversation.aexists():
return await conversation.afirst() return await conversation.afirst()
return await Conversation.objects.acreate(user=user) return await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod @staticmethod
async def adelete_conversation_by_user(user: KhojUser): async def adelete_conversation_by_user(user: KhojUser):

View File

@@ -7,6 +7,7 @@ from django.http import HttpResponse
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
ClientApplication,
Conversation, Conversation,
KhojUser, KhojUser,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
@@ -19,10 +20,24 @@ from khoj.database.models import (
UserSearchModelConfig, UserSearchModelConfig,
) )
# Register your models here.
class KhojUserAdmin(UserAdmin):
    """Admin configuration for KhojUser that adds the phone number to the stock UserAdmin."""

    # The stock UserAdmin sections, preceded by one exposing the phone number.
    fieldsets = (("Personal info", {"fields": ("phone_number",)}), *UserAdmin.fieldsets)
    filter_horizontal = ("groups", "user_permissions")
    search_fields = ("email", "username", "phone_number")
    list_display = ("id", "email", "username", "is_active", "is_staff", "is_superuser", "phone_number")
admin.site.register(KhojUser, UserAdmin) admin.site.register(KhojUser, KhojUserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(SpeechToTextModelOptions) admin.site.register(SpeechToTextModelOptions)
@@ -33,6 +48,7 @@ admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion) admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig) admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig) admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
@admin.register(Conversation) @admin.register(Conversation)

View File

@@ -0,0 +1,46 @@
# Generated by Django 4.2.7 on 2024-01-04 12:22
import django.db.models.deletion
import phonenumber_field.modelfields
from django.db import migrations, models
# Schema migration that introduces the ClientApplication model, a phone number
# on KhojUser, and a link from Conversation to the client application.
class Migration(migrations.Migration):
    # Must run after migration 0024 of the "database" app.
    dependencies = [
        ("database", "0024_alter_entry_embeddings"),
    ]
    operations = [
        # New model holding a name plus a client_id / client_secret pair,
        # with created_at / updated_at timestamps.
        migrations.CreateModel(
            name="ClientApplication",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("name", models.CharField(max_length=200)),
                ("client_id", models.CharField(max_length=200)),
                ("client_secret", models.CharField(max_length=200)),
            ],
            options={
                "abstract": False,
            },
        ),
        # Optional phone number on the user; nullable with a default of None,
        # so existing rows need no value.
        migrations.AddField(
            model_name="khojuser",
            name="phone_number",
            field=phonenumber_field.modelfields.PhoneNumberField(
                blank=True, default=None, max_length=128, null=True, region=None
            ),
        ),
        # Optional link from a conversation to a client application.
        # CASCADE: deleting the client application deletes its conversations.
        migrations.AddField(
            model_name="conversation",
            name="client",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                to="database.clientapplication",
            ),
        ),
    ]

View File

@@ -0,0 +1,13 @@
# Generated by Django 4.2.7 on 2024-01-18 13:24
from typing import List
from django.db import migrations
# Merge migration: joins the two migration branches listed in `dependencies`
# into a single history. It performs no schema changes of its own.
class Migration(migrations.Migration):
    dependencies = [
        ("database", "0025_clientapplication_khojuser_phone_number_and_more"),
        ("database", "0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more"),
    ]
    # Intentionally empty: both parent migrations already carry their own operations.
    operations: List[str] = []

View File

@@ -3,6 +3,7 @@ import uuid
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.db import models from django.db import models
from pgvector.django import VectorField from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
class BaseModel(models.Model): class BaseModel(models.Model):
@@ -13,8 +14,18 @@ class BaseModel(models.Model):
abstract = True abstract = True
class ClientApplication(BaseModel):
    """A client application identified by a client id / client secret pair."""

    # Human-readable label; also used as the string representation.
    name = models.CharField(max_length=200)
    # Identifier for the client application. No uniqueness constraint is declared here.
    client_id = models.CharField(max_length=200)
    # Secret paired with client_id. Declared as a plain CharField.
    # NOTE(review): consider whether this should be stored hashed — confirm with owners.
    client_secret = models.CharField(max_length=200)

    def __str__(self):
        """Return the application's name."""
        return self.name
class KhojUser(AbstractUser): class KhojUser(AbstractUser):
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
phone_number = PhoneNumberField(null=True, default=None, blank=True)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.uuid: if not self.uuid:
@@ -165,6 +176,7 @@ class UserSearchModelConfig(BaseModel):
class Conversation(BaseModel): class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict) conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
class ReflectiveQuestion(BaseModel): class ReflectiveQuestion(BaseModel):

View File

@@ -359,8 +359,8 @@ async def chat(
n: Optional[int] = 5, n: Optional[int] = 5,
d: Optional[float] = 0.18, d: Optional[float] = 0.18,
stream: Optional[bool] = False, stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user: KhojUser = request.user.object user: KhojUser = request.user.object
q = unquote(q) q = unquote(q)
@@ -372,7 +372,7 @@ async def chat(
q = q.replace(f"/{conversation_command.value}", "").strip() q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
@@ -392,7 +392,11 @@ async def chat(
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user): elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format() no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online: elif conversation_command == ConversationCommand.Online:
try: try:

View File

@@ -13,7 +13,12 @@ from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig from khoj.database.models import (
ClientApplication,
KhojUser,
Subscription,
TextToImageModelConfig,
)
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import ( from khoj.processor.conversation.offline.chat_model import (
converse_offline, converse_offline,
@@ -74,6 +79,7 @@ def update_telemetry_state(
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
): ):
user: KhojUser = request.user.object if request.user.is_authenticated else None user: KhojUser = request.user.object if request.user.is_authenticated else None
client_app: ClientApplication = request.user.client_app if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None subscription: Subscription = user.subscription if user and hasattr(user, "subscription") else None
user_state = { user_state = {
"client_host": request.client.host if request.client else None, "client_host": request.client.host if request.client else None,
@@ -83,6 +89,7 @@ def update_telemetry_state(
"server_id": str(user.uuid) if user else None, "server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None, "subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None, "is_recurring": subscription.is_recurring if subscription else None,
"client_id": str(client_app.name) if client_app else None,
} }
if metadata: if metadata:
@@ -113,10 +120,6 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Default return ConversationCommand.Default
async def construct_conversation_logs(user: KhojUser):
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
async def agenerate_chat_response(*args): async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args) return await loop.run_in_executor(executor, generate_chat_response, *args)

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
import stripe import stripe
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import Response
from starlette.authentication import requires from starlette.authentication import requires
from khoj.database import adapters from khoj.database import adapters