mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Search for images similar to query image provided by the user
Example user passes path to an image in query. e.g ~/Pictures/photo.jpg The script should return images in images_embedding most similar to the query image
This commit is contained in:
@@ -48,8 +48,17 @@ def compute_embeddings(image_names, model, embeddings_file, verbose=False):
|
|||||||
return image_embeddings
|
return image_embeddings
|
||||||
|
|
||||||
|
|
||||||
def search(query, image_embeddings, model, count=3):
|
def search(query, image_embeddings, model, count=3, verbose=False):
|
||||||
# First, we encode the query (which can either be an image or a text string)
|
# Set query to image content if query is a filepath
|
||||||
|
if pathlib.Path(query).expanduser().is_file():
|
||||||
|
query_imagepath = pathlib.Path(query).expanduser().resolve(strict=True)
|
||||||
|
query = copy.deepcopy(Image.open(query_imagepath))
|
||||||
|
if verbose:
|
||||||
|
print(f"Find Images similar to Image at {query_imagepath}")
|
||||||
|
else:
|
||||||
|
print(f"Find Images by Text: {query}")
|
||||||
|
|
||||||
|
# Now we encode the query (which can either be an image or a text string)
|
||||||
query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||||
|
|
||||||
# Then, we use the util.semantic_search function, which computes the cosine-similarity
|
# Then, we use the util.semantic_search function, which computes the cosine-similarity
|
||||||
@@ -95,7 +104,7 @@ if __name__ == '__main__':
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# query notes
|
# query notes
|
||||||
hits = search(user_query, image_embeddings, model, args.results_count)
|
hits = search(user_query, image_embeddings, model, args.results_count, args.verbose)
|
||||||
|
|
||||||
# render results
|
# render results
|
||||||
render_results(hits, image_names, args.image_directory, count=args.results_count)
|
render_results(hits, image_names, args.image_directory, count=args.results_count)
|
||||||
|
|||||||
Reference in New Issue
Block a user