from typing import List, Callable

import numpy as np
from PIL import Image
import torch


def l2_normalize(embedding: np.ndarray) -> np.ndarray:
    """Scale a vector to unit L2 norm.

    The norm is computed over all elements of the array
    (``np.linalg.norm`` with default arguments).

    Args:
        embedding (np.ndarray): Input vector to normalize.

    Returns:
        np.ndarray: ``embedding / norm`` when the norm is strictly positive.
        Otherwise (zero norm, or a NaN norm for which the comparison is
        false) the input array itself is returned unchanged.
    """
    norm = np.linalg.norm(embedding)

    # Guard against division by zero: an all-zero vector has no direction,
    # so it is handed back as-is rather than turned into NaNs.
    return embedding / norm if norm > 0 else embedding


def encode_image(
    image: Image.Image,
    preprocess: Callable[[Image.Image], torch.Tensor],
    model: torch.nn.Module,
    device: torch.device,
) -> List[float]:
    """Preprocess an image, encode it with ``model`` and L2-normalize it.

    Steps:
        1. Run ``preprocess`` on the image to obtain a tensor.
        2. Add a leading batch dimension and move the tensor to ``device``.
        3. Call ``model.encode_image`` with gradient tracking disabled.
        4. Take the first row of the result, cast it to float32, convert it
           to NumPy and L2-normalize it.

    Args:
        image (Image.Image): Input image to encode.
        preprocess (Callable[[Image.Image], torch.Tensor]): Callable that
            turns the image into a tensor (without a batch dimension).
        model (torch.nn.Module): Model used for encoding. It must expose an
            ``encode_image(tensor)`` method; a plain ``torch.nn.Module``
            does not provide one.
        device (torch.device): Device the input tensor is moved to. The
            model itself is not moved by this function.

    Returns:
        List[float]: The normalized embedding as a list of Python floats.
    """
    # Single-image batch: unsqueeze(0) adds the leading batch axis.
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        image_features = model.encode_image(image_input)

    # Cast to float32 before leaving torch: NumPy has no bfloat16 dtype
    # (Tensor.numpy() raises TypeError for it), and normalizing float16
    # features in half precision would lose accuracy. For float32 outputs
    # the cast is a no-op.
    embedding = image_features[0].float().cpu().numpy()

    embedding_norm = l2_normalize(embedding)

    return embedding_norm.tolist()