mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Use Sigmoid to normalize cross-encoder score between 0-1
- While sigmoid normalization isn't required for reranking, normalizing both the encoder and cross-encoder scores to distance metrics makes them easier to reason about - Softmax wasn't required as we don't need probabilities; sigmoid is good enough to get a distance metric
This commit is contained in:
@@ -11,10 +11,6 @@ from pgvector.django import CosineDistance
|
|||||||
from django.db.models.manager import BaseManager
|
from django.db.models.manager import BaseManager
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from pgvector.django import CosineDistance
|
|
||||||
from django.db.models.manager import BaseManager
|
|
||||||
from django.db.models import Q
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
# Import sync_to_async from Django Channels
|
# Import sync_to_async from Django Channels
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
@@ -26,6 +27,6 @@ class CrossEncoderModel:
|
|||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||||
cross__inp = [[query, hit.additional[key]] for hit in hits]
|
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
|
||||||
return cross_scores
|
return cross_scores
|
||||||
|
|||||||
Reference in New Issue
Block a user