
Academic Sharing: How the DINO Model Is Transforming Vertical E-commerce Image Search

This article is translated from "Vertical E-commerce Image Search Upgraded: DINO Model Brings Accurate Matching Experience", published by the AWS team in October 2024. Original link: https://aws.amazon.com/cn/blogs/china/vertical-e-commerce-image-search-upgraded-dino-model-brings-accurate-matching-experience/. Authors: Jiang Bingkun, Ji Junxiang, Lyu Haoran, Yin Zhenyu, Hong Dan, Hua Cheng.

When shopping online or offline, we sometimes struggle to find the right words to describe what we are looking for. As the saying goes, "a picture is worth a thousand words": showing an actual item or image is usually easier than describing it in text, especially when searching for products in apps or mini-programs.

However, building an e-commerce image search system poses numerous challenges, such as:

(1) In private-domain or category-specific e-commerce, some products and their images are hard to tell apart;

(2) User-uploaded images differ significantly from product images in terms of shooting angle, lens distortion, lighting, and background interference;

(3) The system has strict response time requirements.

In this article, we describe how to build a vertical model for shoes and clothing from scratch, yielding a low-latency, high-precision image search solution.

I. Business Background

Image search can enhance customer engagement in retail and e-commerce, especially for clothing retailers (clothes, pants, shoes, accessories, etc.). Clothing is one of the most important product categories in image search. Research shows that 36% of consumers have used image search, and 74% find traditional text search inadequate for locating the right products.

Apparel items, such as sports shoes and clothing, are often highly similar to one another. Most shoes share similar shapes and styles, so telling them apart requires very fine-grained features. For example, different shoe models can look nearly identical, as shown below.

[Figure 1: Different shoe models that look very similar]

In this article, you will learn how to build a similarity search solution for a product catalog like this. The solution primarily integrates Amazon SageMaker, Amazon Aurora MySQL, and Amazon OpenSearch (for vector data storage).

II. Business Requirements

1. Efficient Object-Based Search

When multiple items or targets exist in a user's input image, the system allows users to search for specific objects, enabling them to focus only on products of interest rather than searching the entire image. This feature improves search efficiency, helping users find what they need faster.

2. Automatic Product Recognition

The system can automatically identify products in images. In the future, this feature can be integrated with e-commerce platforms to recommend relevant products based on recognized items, promoting sales.

3. Search Accuracy

Even when user query images and indexed images are taken from different angles and under different lighting, the system must achieve over 85% Top-5 recall accuracy across thousands of categories, correctly matching products to relevant images through visual feature analysis. High accuracy keeps search results relevant.

4. Security and Privacy

The system must support private deployment and comply with applicable privacy regulations and compliance requirements.

5. Indexing and Storage

The system needs to efficiently index and store over 1 million images and related metadata, such as tags, descriptions, and other relevant information, to support fast search and retrieval.

III. Overall Solution

1. Architecture

[Figure 2: Solution architecture]

2. Solution Steps

2.1 Offline Processing (White Lines)

(1) Start a Notebook to read all images from S3;

(2) Call Bedrock to tag the images, and use the tags to filter the training data;

(3) Save tagging results in Aurora MySQL;

(4) Launch a SageMaker training job on the filtered data, then deploy the trained embedding model to SageMaker;

(5) Call the embedding model to generate embeddings for all existing product images and store results in OpenSearch.
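Step (2) above can be sketched as follows. This is a minimal sketch that assumes an Anthropic Claude model on Bedrock, called through the `bedrock-runtime` client's `invoke_model` with the Anthropic Messages request body. The prompt wording, the label keys, and the "answer with JSON only" convention are our assumptions, because the article does not publish its prompt.

```python
import base64
import json

# Hypothetical prompt: the label names mirror the labeling scheme in section IV.1.2,
# but the exact wording used by the authors is not published.
TAGGING_PROMPT = (
    "Describe this product image as a JSON object with keys: has_model (bool), "
    "model_count (int), real_world_scene (bool), worn_by_model (bool), "
    "shooting_angle (str), complete_view (bool). Answer with JSON only."
)


def build_tagging_body(image_bytes: bytes, media_type: str = "image/jpeg") -> str:
    """Build an Anthropic Messages API request body for Bedrock invoke_model."""
    return json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image",
                 "source": {"type": "base64", "media_type": media_type,
                            "data": base64.b64encode(image_bytes).decode("utf-8")}},
                {"type": "text", "text": TAGGING_PROMPT},
            ],
        }],
    })


def tag_image(bedrock_runtime, model_id: str, image_bytes: bytes) -> dict:
    """Call the model and parse its JSON answer into a label dict."""
    response = bedrock_runtime.invoke_model(modelId=model_id, body=build_tagging_body(image_bytes))
    payload = json.loads(response["body"].read())
    # Messages API responses carry the generated text in content[0]["text"]
    return json.loads(payload["content"][0]["text"])
```

For example, `tag_image(boto3.client("bedrock-runtime"), "anthropic.claude-3-sonnet-20240229-v1:0", image_bytes)`, where the model ID is only an example of a vision-capable Claude model. The resulting dict is what step (3) would store in Aurora MySQL.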

2.2 Real-time Processing (Yellow Lines)

(1) Frontend loads pages and product images through CloudFront;

(2) CloudFront reads static data from S3;

(3) When uploading an image, CloudFront forwards the request to API Gateway;

(4) API Gateway forwards the request to EC2;

(5) EC2 sends the image to Lambda;

(6) Lambda sends the image to GroundingDINO for object detection. If no target item is detected, it returns a message to the frontend. If several target items are detected, it returns their bounding-box coordinates so that the user can select one. If exactly one target item is detected, or the user has already selected one, it crops the target from the image using the bounding box returned by GroundingDINO and proceeds to the next step;

(7) The cropped target image is passed to Lambda;

(8) Lambda calls the embedding model to obtain vectors;

(9) Query OpenSearch with the vector to get the top 5 product codes;

(10) Query Aurora with product codes to get detailed product data and return to the frontend.
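The Lambda side of steps (6) to (8) can be sketched as below. This is a minimal sketch with several assumptions:

* It uses the `sagemaker-runtime` `invoke_endpoint` call with JSON bodies.
* It assumes the detection endpoint decodes the JSON body into the dict that its `predict_fn` reads. The article does not show that `input_fn`.
* `EMBEDDING_ENDPOINT` is a hypothetical name. Only `grounded-sam` appears in the deployment code.
* Turning `mask_image_output` into a bucket and key depends on `save_file_to_s3`, whose body the article elides.

```python
import json

DETECTION_ENDPOINT = "grounded-sam"    # endpoint_name from the deployment code
EMBEDDING_ENDPOINT = "dino-embedding"  # hypothetical name for the DINO embedding endpoint


def detection_payload(image_s3_uri, output_dir, prompt="clothes . pants . hats . shoes", box=None):
    """Request body with the keys the GroundingDINO predict_fn reads."""
    payload = {"input_image": image_s3_uri, "prompt": prompt, "output_mask_image_dir": output_dir}
    if box is not None:
        payload["boxes"] = box  # the single [cx, cy, w, h] box the user selected
    return payload


def embedding_payload(bucket, key):
    """Request body for the embedding endpoint's application/json input_fn."""
    return {"bucket": bucket, "file_name": key}


def invoke_json(runtime, endpoint_name, payload):
    """Call a SageMaker endpoint with a JSON body and decode the JSON response."""
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read())


def next_step(detection_result):
    """Decide what step (6) does with the detection response."""
    if "error_message" in detection_result:
        return "return_to_frontend"   # no target item found, or detection failed
    if "boxes" in detection_result:
        return "ask_user_to_select"   # several items: the frontend shows the boxes
    return "embed_cropped_image"      # one item, already cropped and saved to S3
```

In a handler, `runtime = boto3.client("sagemaker-runtime")` would be passed to `invoke_json`. After the user picks a box, the same detection payload is re-sent with `box=...` so that the endpoint skips detection and crops directly.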

IV. Technical Challenges and Solutions

1. Image Preprocessing

1.1 Technical Challenges

(1) Some images are unsuitable for training: certain images show only part of a product (e.g., a shoe sole) and are not appropriate for model training.

(2) Inconsistent image quality and varying angles: because the images come from diverse sources, their quality and shooting angles vary, which reduces training effectiveness.

1.2 Solution

As shown in the figure below, we used the multimodal capabilities of large language models to annotate the images. For our scenario, we designed a labeling scheme covering "whether a model appears," "number of models," "whether it is a real-world scene," "whether the item is worn by a model," "shooting angle," "partial or complete view," and so on. With these labels, we can filter out images, such as shoe soles, that help neither training nor search.

We also used these labels to split the data into training and test sets: all real-world scene images were assigned to the test set.

[Figure 3: Image labeling with a multimodal LLM]
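The filtering and train/test split described above can be sketched as follows. This is a minimal sketch under stated assumptions. The `ImageLabels` record and its field names are hypothetical stand-ins for the LLM-generated labels. The rule "drop partial views, send real-world scenes to the test set" follows the text, but the authors' actual schema is not published.

```python
from dataclasses import dataclass


@dataclass
class ImageLabels:
    """Hypothetical record of the LLM-generated labels for one image."""
    image_key: str
    has_model: bool          # whether a model (person) appears
    model_count: int         # number of models
    real_world_scene: bool   # real-world photo rather than a studio product shot
    worn_by_model: bool      # whether the item is being worn
    shooting_angle: str      # e.g. "front", "side", "bottom"
    complete_view: bool      # False for partial views such as a close-up of a sole


def split_by_labels(records):
    """Drop partial views; real-world scenes go to the test set, the rest to training."""
    train, test = [], []
    for record in records:
        if not record.complete_view:
            continue  # e.g. shoe soles: not useful for training or search
        (test if record.real_world_scene else train).append(record)
    return train, test
```

Keeping every real-world photo out of training means the test set measures exactly the shift the system faces in production, from studio product shots to user photos.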

2. Object Detection and Segmentation

2.1 Technical Challenges

(1) User-uploaded search images cannot be constrained in advance, so an image may contain no products or several. The system must determine whether a target in the image belongs to the company's product categories.

(2) Letting users choose among multiple products: when an image contains several products, a mechanism is needed for users to select the one they are interested in.

2.2 Solution

We use Grounding DINO to detect shoes, hats, pants, and similar items, and then crop the corresponding rectangular regions directly, keeping everything inside each box, background included. We did not use SAM to segment irregularly shaped objects: simply cropping the bounding box proved sufficient, while pixel-level segmentation with SAM actually degraded model performance.

[Figure 4: Object detection and cropping with Grounding DINO]

First, we build a placeholder model package and upload it to an S3 bucket. The GroundingDINO weights themselves are downloaded from Hugging Face when the endpoint loads the model:

import boto3
import sagemaker
from sagemaker import serializers, deserializers
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
s3_model_prefix = "east-ai-models/grounded-sam"
# model_fn downloads the weights from Hugging Face, so the package only needs a placeholder file
!touch dummy
!rm -f model.tar.gz
!tar czvf model.tar.gz dummy
s3_model_artifact = sess.upload_data("model.tar.gz", bucket, s3_model_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_model_artifact}")
!rm -f dummy

Next, we create the model and deploy it to a SageMaker endpoint. All inference code lives in the local ./code/ directory:

framework_version = "2.3.0"
py_version = "py311"
instance_type = "ml.g4dn.xlarge"
endpoint_name = "grounded-sam"

model = PyTorchModel(
    model_data=s3_model_artifact,
    entry_point="inference.py",
    source_dir="./code/",
    role=role,
    framework_version=framework_version,
    py_version=py_version,
)
print("Model deployment takes about 7-8 minutes, please wait" + "." * 20)
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)
print("Model deployment is complete; you can proceed to the next steps" + "." * 20)

Next, prepare the custom inference script inference.py in ./code/ (the entry point used above). The model is loaded in model_fn and the inference logic is defined in predict_fn. The core code is as follows:

import os
import io
import json
import math
import uuid

import boto3
import numpy as np
import torch
from PIL import Image
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download


def get_detection_boxes(image_source: Image.Image, model: dict, prompt: str = "clothes . pants . hats . shoes") -> tuple[list, list, list]:
    """Run GroundingDINO; return normalized (cx, cy, w, h) boxes, scores and matched phrases."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    box_threshold = 0.3
    text_threshold = 0.25
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_transformed, _ = transform(image_source, None)
    try:
        boxes, logits, phrases = predict(
            model=model["dino"],
            image=image_transformed,
            caption=prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=device,
        )
    except Exception as e:
        print(e)
        raise  # let predict_fn report the failure instead of unpacking None
    boxes_list = boxes.numpy().tolist()
    logits_list = logits.numpy().tolist()
    return boxes_list, logits_list, phrases


def load_model_hf(repo_id, filename, ckpt_config_filename):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device  # set before the model is built
    model = build_model(args)
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(f"Model loaded from {cache_file} => {log}")
    model.to(device)
    model.eval()
    return model


def model_fn(model_dir):
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filename = "groundingdino_swint_ogc.pth"
    ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
    model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
    model_dic = {"dino": model, "sam": ""}  # SAM is not used
    return model_dic


def save_file_to_s3(mask_image, file_extension, output_mask_image_dir: str):
    # Save the image to S3 ......
    return mask_image_output


def crop_images_from_boxes(image_source: Image.Image, boxes: list, scale_factor: float = 1.0, target_size: int = 400) -> list:
    """Crop each normalized (cx, cy, w, h) box and resize it so its longer side equals target_size."""
    cropped_images = []
    width, height = image_source.size
    for box in boxes:
        cx, cy, w, h = [coord * scale_factor for coord in box]
        # Compute the top-left and bottom-right corners of the bounding box in pixels
        x1 = max(0, math.floor((cx - w / 2) * width))
        y1 = max(0, math.floor((cy - h / 2) * height))
        x2 = min(width, math.ceil((cx + w / 2) * width))
        y2 = min(height, math.ceil((cy + h / 2) * height))
        # Crop only if the bounding box has a non-empty area inside the image
        if x2 > x1 and y2 > y1:
            cropped_image = image_source.crop((x1, y1, x2, y2))
            # Resize the crop proportionally to the target size
            cropped_width, cropped_height = cropped_image.size
            scale = min(target_size / cropped_width, target_size / cropped_height)
            new_width = int(cropped_width * scale)
            new_height = int(cropped_height * scale)
            cropped_image = cropped_image.resize((new_width, new_height), resample=Image.BICUBIC)
            cropped_images.append(cropped_image)
    return cropped_images


def predict_fn(input_data, model):
    print("=================Dino detect start=================")
    try:
        # input_image is an S3 URI such as s3://bucket/path/to/image.jpg
        file_extension = os.path.splitext(input_data["input_image"])[1][1:].lower()
        dir_lst = input_data["input_image"].split("/")
        s3_client = boto3.client("s3")
        s3_response_object = s3_client.get_object(Bucket=dir_lst[2], Key="/".join(dir_lst[3:]))
        img_bytes = s3_response_object["Body"].read()
        image_source = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        if "boxes" not in input_data:
            prompt = input_data["prompt"]
            boxes, logits, phrases = get_detection_boxes(image_source, model, prompt)
            if len(boxes) == 0:
                return {"error_message": "The image does not contain any object needed"}
            elif len(boxes) > 1:
                # Several targets: return them so the user can choose one
                return {"boxes": boxes, "file_type": file_extension, "logits": logits, "phrases": phrases}
        # Exactly one detected box, or the box the user selected
        boxes = [input_data["boxes"]] if "boxes" in input_data else boxes
        cropped_images = crop_images_from_boxes(image_source, boxes)
        mask_image_output = save_file_to_s3(cropped_images[0], file_extension, input_data["output_mask_image_dir"])
        return {"mask_image_output": mask_image_output}
    except Exception as e:
        print(e)
        return {"error_message": str(e)}

3. Embedding Model

3.1 Technical Challenges

Traditional image embedding models often face the following issues when used for vector retrieval:

(1) Lack of annotated images: Training models requires large amounts of annotated image data, but obtaining these annotations can be costly and difficult.

(2) Models need high precision for fine-grained comparison: To accurately match similar products, embedding models need sufficient precision to capture subtle differences.

(3) Insufficient robustness of model output embeddings: They can be significantly affected by background, clothing deformation, shooting angle, lighting, etc.

(4) Need for private deployment options to ensure security and privacy: For security and privacy considerations, models may need to be deployed in local private environments.

(5) Models should be customizable and scalable: To meet different needs, models should have a certain degree of customization and scalability.

3.2 Solution

First, we pre-train a DINO + ViT model on private product images. This requires no annotation, because DINO is self-supervised and its attention naturally focuses on the main subject without being easily distracted by the background. In the second stage, we fine-tune the model with contrastive learning or a classification objective to further improve recall. The figure below visualizes the attention maps of the DINO model and shows its advantage over traditional models. In the DINO column, the model's attention separates the subject from background clutter, whereas the attention of traditional supervised models fails to capture the main subject of the image accurately.

[Figure 5: Attention visualization, DINO vs. traditional supervised models]
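The DINO self-distillation objective behind this pre-training can be sketched in plain Python for a single view pair. The teacher's output is centered and sharpened with a low temperature. The student is trained to match that output under cross-entropy. The center and the teacher weights are both exponential moving averages.

This is an illustrative sketch, not the authors' training code. The temperatures (student 0.1, teacher 0.04), center momentum 0.9, and teacher momentum 0.996 are DINO-paper defaults. The real implementation works on batched tensors with multi-crop views and warms the teacher temperature up to 0.07.

```python
import math


def log_softmax(logits, temperature):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_total for x in scaled]


def dino_loss(student_logits, teacher_logits, center, student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy H(P_t, P_s) with a centered, sharpened teacher as the soft target."""
    centered = [t - c for t, c in zip(teacher_logits, center)]
    p_teacher = [math.exp(v) for v in log_softmax(centered, teacher_temp)]
    log_p_student = log_softmax(student_logits, student_temp)
    return -sum(pt * ls for pt, ls in zip(p_teacher, log_p_student))


def update_center(center, teacher_batch, momentum=0.9):
    """EMA of the batch-mean teacher output; subtracting it helps prevent collapse."""
    batch_mean = [sum(col) / len(teacher_batch) for col in zip(*teacher_batch)]
    return [momentum * c + (1 - momentum) * b for c, b in zip(center, batch_mean)]


def ema_update_teacher(teacher_params, student_params, momentum=0.996):
    """Teacher weights track an exponential moving average of the student weights."""
    return [momentum * t + (1 - momentum) * s for t, s in zip(teacher_params, student_params)]
```

No labels appear anywhere in the loss. This is why the first stage can run on unannotated product images.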

During algorithm development, we compared DINO with DINOv2 and triplet loss with cross-entropy loss, and we compared ViT and CNN backbones. Based on extensive experiments, we reached the following conclusions:

(1) Triplet loss is currently far less cost-effective than cross-entropy loss. With the same number of training epochs it did not converge at all, and mAP stayed in the single digits. The reason is that cross-entropy loss optimizes over the full class distribution in every gradient update, whereas triplet loss only optimizes the sampled positive and negative pairs, so their training efficiency differs dramatically. That said, triplet loss optimizes the features directly, which in principle suits vector matching better. Making it work would likely require larger batch sizes, more carefully tuned hyperparameters, and more thorough hard negative mining.

(2) As expected, DINOv2 (DINO with an MAE-style reconstruction loss) performs worse than DINO in this scenario. Even the Large and Giant ViT-DINOv2 models underperform the Base ViT-DINO model. Our current hypothesis is that reconstruction losses (MAE loss) are ill-suited to discriminative scenarios. Discriminative losses (cross-entropy loss) work better there because they focus on lower-frequency features that suit discriminative tasks.

(3) In our experiments, DINOv1 is the most suitable pre-training algorithm for vector search, and this pre-training method can even partly compensate for differences in model size.

(4) If resources allow, larger ViT models could be pre-trained with the DINO framework, since the largest officially released DINO ViT is only ViT-Base; no Large version is available. In the future, pre-training could combine Google Landmarks v2, ImageNet, and private datasets.
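To make the triplet-loss discussion in point (1) concrete, here is a minimal pure-Python sketch of triplet loss with batch-hard mining, one common form of hard negative mining. For each anchor, it uses the farthest same-product image as the positive and the closest other-product image as the negative. The cosine distance and the 0.2 margin are our assumptions; the article does not give its triplet settings.

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity (embeddings are assumed to be non-zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def batch_hard_triplet_loss(embeddings, labels, margin=0.2):
    """Mean over anchors of max(0, d(hardest positive) - d(hardest negative) + margin)."""
    losses = []
    for i, (anchor, label) in enumerate(zip(embeddings, labels)):
        positives = [cosine_distance(anchor, e)
                     for j, (e, l) in enumerate(zip(embeddings, labels)) if l == label and j != i]
        negatives = [cosine_distance(anchor, e) for e, l in zip(embeddings, labels) if l != label]
        if not positives or not negatives:
            continue  # this anchor has no valid triplet in the batch
        losses.append(max(0.0, max(positives) - min(negatives) + margin))
    return sum(losses) / len(losses) if losses else 0.0
```

Each anchor contributes only one positive and one negative distance. A softmax cross-entropy head, by contrast, pushes against every class in every step. This matches the efficiency gap described in (1), and it shows why triplet training leans on large batches: only a large batch is likely to contain hard negatives.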

Deploying the trained DINO model on SageMaker requires providing an inference script file inference.py. The main code is as follows:

...
def predict_fn(single_data, model):
    """
    Predict a result using a single data
    :param single_data: a single RGB image, as returned by input_fn
    :type single_data: PIL.Image.Image
    :param model: the loaded model
    :type model: torch.nn.Module
    :return: the embedding vector of the image
    :rtype: numpy.ndarray
    """
    imsize = 648
    transform = pth_transforms.Compose([
        pth_transforms.Resize((imsize, imsize), interpolation=pth_transforms.InterpolationMode.BICUBIC),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    image = transform(single_data)
    with torch.no_grad():
        output = model(image[None].cuda())
    # Move the tensor to CPU, then convert it to a NumPy array
    numpy_array = output.cpu().numpy()
    return numpy_array[0]


def input_fn(input_data, request_content_type):
    """Load one request (requests arrive one by one) as an RGB PIL image."""
    if request_content_type == "application/json":
        # JSON body {"bucket": ..., "file_name": ...} pointing to an image in S3
        json_request = json.loads(input_data)
        file_byte_string = s3_client.get_object(
            Bucket=json_request["bucket"], Key=json_request["file_name"])["Body"].read()
        im = Image.open(io.BytesIO(file_byte_string))
        return im.convert("RGB")
    elif request_content_type == "application/x-image":
        im = Image.open(io.BytesIO(input_data))
        return im.convert("RGB")
    else:
        # Handle other content types here, or reject unsupported ones
        raise ValueError(f"Unsupported content type: {request_content_type}")


def model_fn(model_dir):
    pretrained_weights = os.path.join(model_dir, "checkpoint.pth")
    config_path = os.path.join(model_dir, "config.json")
    print(os.path.abspath(config_path))
    # Open the config file and load its contents
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)
    print(f"loading model info: {model_config}")
    # Load the trained weights if they exist
    if os.path.isfile(pretrained_weights):
        model = vits.__dict__[model_config["arch"]](
            patch_size=model_config["patch_size"],
            drop_path_rate=model_config["drop_path_rate"],  # stochastic depth
        )
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
    else:
        print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.")
        # The Google Landmarks v2 checkpoint is a DINO ViT-S/16
        model = vits.vit_small(patch_size=16)
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
    model = model.cuda()
    model.eval()
    return model
...

4. Vector Search

4.1 Technical Challenges

(1) Product recall, not image recall: the ultimate goal is to find products from an image, not merely to find similar images.

(2) Efficient retrieval from vector storage: the vector database must support fast retrieval at the million-vector scale, and search results must carry unique product identifiers (such as product codes).

4.2 Solution

We use OpenSearch to store the image vectors together with their product codes, so that the product codes come back directly with the vector similarity results. We use Faiss HNSW as the retrieval algorithm and cosine similarity as the similarity measure, matching the function used during fine-tuning. The key considerations are shown in the figure below:

[Figure 6: Key considerations for vector search]
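Because the goal is product recall, step (9) has to end with distinct product codes rather than images. A minimal sketch follows, under two assumptions: a product may have several indexed images, and the k-NN query asks for more than five neighbours. It walks the OpenSearch hits (`response["hits"]["hits"]`, best first) and keeps the first, highest-scoring hit of each product. The article does not say how the authors handle repeated products, so treat this as one possible approach.

```python
def top_products(hits, k=5):
    """Collapse image-level hits (sorted best first) into at most k distinct product codes."""
    seen = set()
    products = []
    for hit in hits:
        code = hit["_source"]["product_code"]
        if code not in seen:
            seen.add(code)
            products.append((code, hit["_score"]))  # keep the best score per product
            if len(products) == k:
                break
    return products
```

The resulting codes are what step (10) looks up in Aurora.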

OpenSearch offers several vector search algorithms. Based on the comparison in the figure below, we ultimately chose Faiss HNSW for the vector index.

[Figure 7: Comparison of OpenSearch vector search algorithms]
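A possible index definition and query body for this setup are sketched below as plain dicts, for example for opensearch-py's `client.indices.create(index=..., body=...)` and `client.search(index=..., body=...)`.

This is a sketch under explicit assumptions:

* The field names `embedding` and `product_code` are hypothetical.
* The dimension 768 assumes a ViT-Base backbone.
* The HNSW parameters `m=16` and `ef_construction=128` are illustrative values, not the authors' settings.
* The default `k=20` is an assumed value larger than 5, in case several neighbours belong to the same product.

To get cosine ranking with the Faiss engine, the sketch L2-normalizes every vector and uses `innerproduct`, which ranks identically to cosine on unit vectors. Depending on your OpenSearch version, `cosinesimil` may also be available with Faiss directly.

```python
import math

EMBEDDING_DIM = 768  # assumption: ViT-Base output size (ViT-Small would be 384)


def l2_normalize(vector):
    """Scale to unit length so that inner product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def knn_index_body(dim=EMBEDDING_DIM):
    """Index settings and mappings: a Faiss HNSW vector field plus a keyword product code."""
    return {
        "settings": {"index": {"knn": True}},
        "mappings": {
            "properties": {
                "embedding": {
                    "type": "knn_vector",
                    "dimension": dim,
                    "method": {
                        "name": "hnsw",
                        "engine": "faiss",
                        "space_type": "innerproduct",
                        "parameters": {"m": 16, "ef_construction": 128},
                    },
                },
                "product_code": {"type": "keyword"},
            }
        },
    }


def knn_query(vector, k=20):
    """k-NN query that returns only the product code of each neighbouring image."""
    return {
        "size": k,
        "query": {"knn": {"embedding": {"vector": l2_normalize(vector), "k": k}}},
        "_source": ["product_code"],
    }
```

Every vector written at indexing time in step (5) must be normalized the same way as the query vector. Otherwise inner-product scores are no longer cosine scores.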

In summary, the solution spans image processing, object detection, image segmentation, embedding, and vector search, and it must address challenges in data, model accuracy, deployment environment, and search results. With appropriate data preprocessing, model selection, and system design, an efficient image-based product retrieval system can be built.

V. Experimental Test Results

[Figure 8: CMC test results]

The figure above shows the CMC (Cumulative Match Characteristic) test results. The x-axis, rank n, is the number of top products retrieved. The y-axis is the probability that the target product is among those top n products. Our test product library contains around 6,000 categories, and all query images are real-world scene photos. For 75% of the queries the correct product is retrieved at rank 1, and for 86% it is retrieved within the top 5. This meets the client's requirement that the target product appear in the top 5 results at least 85% of the time. Business staff have confirmed that search is largely unaffected by background and that its ability to distinguish fine details approaches or matches human performance.
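A CMC curve like the one above can be computed from ranked results as sketched below. The input layout is our assumption: a dict from query ID to the list of retrieved product codes (best first), and a dict from query ID to the true product code. With this layout, rank-1 accuracy is `cmc[0]`, and the client's requirement reads `cmc[4] >= 0.85`.

```python
def cmc_curve(ranked_products, ground_truth, max_rank=10):
    """cmc[n - 1] = fraction of queries whose true product appears in the top n results."""
    hits_at_rank = [0] * max_rank
    for query_id, ranked in ranked_products.items():
        target = ground_truth[query_id]
        if target in ranked[:max_rank]:
            first = ranked.index(target)  # 0-based rank of the first correct hit
            for n in range(first, max_rank):
                hits_at_rank[n] += 1      # counted at this rank and every later one
    total = len(ranked_products)
    return [hits / total for hits in hits_at_rank]
```

Because a query counts as a hit at its first correct rank and at every rank after it, the curve never decreases. This is why it rises from the rank-1 value towards the Top-5 value in the figure.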

VI. Conclusion

This article showed how to train a model on clothing and footwear product images and how to perform image search by detecting and cropping target items with GroundingDINO. The approach meets enterprise requirements for high-precision search, particularly in vertical industries, and improves the user search experience. It can also be extended to other verticals such as e-commerce, gaming, short video, healthcare, and manufacturing.

References

(1) Paper: "Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection," Authors: Shilong Liu, Zhaoyang Zeng, Tianhe Ren, Feng Li, Hao Zhang, Jie Yang, Qing Jiang, Chunyuan Li, Jianwei Yang, Hang Su, Jun Zhu, Lei Zhang. Link: https://arxiv.org/abs/2303.05499

(2) Latest DINO API of DINO-X Platform: https://cloud.deepdataspace.com

(3) Grounding DINO Playground: https://cloud.deepdataspace.com/playground/grounding_dino