DDS-LOGO

学术分享:DINO 模型如何创新垂直电商图像搜索

本文转自 AWS 团队于 2024 年 10 月发布的基于 DINO 模型构建的垂直电商图像搜索解决方案《垂直电商图像搜索再升级:DINO 模型带来精准匹配体验》,原文链接:https://aws.amazon.com/cn/blogs/china/vertical-e-commerce-image-search-upgraded-dino-model-brings-accurate-matching-experience/, 作者江炳坤、姬军翔、吕浩然、尹振宇、洪丹、华成。

我们在网上或者线下购物的时候,有时候很难找到合适的词语来描述你正在寻找的东西。俗话说:”一张图片胜过千言万语”,通常情况下,展示实物或图片比用文字描述物品更容易,尤其是在使用 APP 或者小程序寻找产品的时候。

然而,构建一个电商级图片搜索需要面临诸多挑战,例如:

(1)对于私域,或品类电商来说,其部分产品及图片区分度较低;

(2)用户输入的图片在拍摄角度,镜头变形,光照,背景干扰等方面和产品图片的区别较大;

(3)系统对响应时间的要求比较苛刻等。

在这篇文章中,我们将介绍如何从头构建一个鞋服类的垂直模型,从而实现低延迟,高精度的图片搜索解决方案。

一、业务背景介绍

图片搜索可以提高零售业务和电子商务中的客户参与度,尤其是对服装类(衣服、裤子、鞋、服装饰品等)零售商而言。服装类是在图片搜索中最重要的产品类型。调研报告显示有 36% 的消费者曾经使用过图片搜索,有 74% 的消费者认为传统的文字搜索很难帮助他们找到正确的产品。

由于行业的特性,服装类大多具有非常高的相似度,比如运动鞋和衣服,大多数鞋的形状和风格非常类似,需要通过非常细粒度的特征来进行识别。比如下面不一样型号的鞋子,会非常相似。

图1.png

在这篇文章中,您将学习如何构建一个类似的产品目录相似度搜索解决方案。该方案主要集成 Amazon SageMaker 和亚马逊关系数据库服务 Amazon Aurora MySQL,向量数据存储 Amazon OpenSearch。

二、业务需求分解

1. 基于对象的高效搜索

当用户输入的图片中同时存在多个商品或目标时,允许用户在图像中搜索特定的对象或物品,这样他们能够只搜索感兴趣的产品,而不是搜索整个图像。这种功能可以提高搜索效率,让用户更快地找到所需内容。

2. 自动产品识别

系统能够自动识别图像中的产品。将来,这项功能可以与电子商务平台集成,根据识别出的产品向用户推荐相关商品,促进销售。

3. 搜索准确性

用户搜索的图片和索引库中待比对的图片在不同角度、不同光线条件下拍摄,系统在万级别品类下,Top5 的召回也能够达到 85% 以上的准确率,将产品与相关图像正确匹配,这是基于对图像视觉特征的分析。高准确度可以确保搜索结果的相关性。

4. 安全和隐私

系统可以进行私有化部署,并确保符合相关的隐私法规和合规要求。

5. 索引和存储

系统需要高效地索引和存储超过 100 万张图像数据,以及相关的元数据,如标签、描述和其他相关信息,以支持快速搜索和检索。

三、整体方案

1. 参考架构图

图2.png

2. 方案步骤

2.1 离线处理(白色线条部分)

(1)启动一个 Notebook 读取 S3 里面的所有的图片;

(2)调用 Bedrock 进行图片打标处理,用于过滤用来训练的数据;

(3)打完标记的结果放到 Aurora Mysql 里面保存;

(4)启动 Sagemaker 的模型训练节点,使用过滤后的训练数据进行训练。将训练完后的 embedding 模型部署到 Sagemaker;

(5)调用 embedding 模型对现有的所有产品图片进行 embedding,结果存入 OpenSearch;

2.2 实时处理(黄色线条部分)

(1)前端通过 Cloudfront 加载页面和产品图片;

(2)Cloudfront 读取 S3 中的静态数据;

(3)当上传图片的时候,Cloudfront 会将请求转发到 API Gateway;

(4)API Gateway 将请求转发到 EC2;

(5)EC2 将图片发送到 Lambda;

(6)Lambda 将图片发送到 GroundingDINO 进行目标检测。如果图片中没有任何目标物品,则返回前端;如果有多个目标物品,则将检测到的目标物品的坐标返回给前端,以允许用户进行物品选择;如果只有一个目标物品,或者用户已经选择了目标物品,则根据 GroundingDINO 返回的长方形框剪切出目标图片,进入下一步;

(7)将剪切出目标图片通过 Lambda;

(8)Lambda 调用 embedding 模型获取向量;

(9)通过向量查询 OpenSearch 获取 top5 的产品代码;

(10)通过产品代码查询 Aurora 得到产品详细数据并返回前端。

四、技术难点以及解决思路

1. 图像预处理

1.1 技术难点

(1)存在不适合训练的图片:某些图片可能只显示产品的部分视角(如鞋底),这对于训练模型来说可能不太合适。

(2)图片质量不一致,角度不同:由于图像来源的多样性,图像质量和拍摄角度可能存在差异,这会影响模型的训练效果。

1.2 解决方案

参照下图,我们利用大语言模型最新的多模态功能,输入图片,让模型对图片的进行图片标注,在我们的场景中,我们设计了如下标签体系。“是否出现模特”,“模特人数”,“是否真实世界的场景”,“是否穿在模特身上”,“拍摄角度”,“局部还是整体”等,通过这些图片,我们可以过滤掉比如鞋底这类对训练和搜索都没有帮助的图片。

同时,我们也利用这些标签进行训练集和测试集的划分。真实世界场景的图片都被划分到测试集。

图3.png

2. 目标检测和分割

2.1 技术难点

(1)用户使用的搜索图片无法做预先的限定,会出现不包含任何产品和包含多个产品的情况。如何确定图像中的目标是否为公司销售的产品类别:需要一种方法来识别图像中的目标是否属于公司销售的产品范围。

(2)如果检测到多个产品,需要用户选择:当图像中包含多个产品时,需要提供一种机制让用户选择感兴趣的产品。

2.2 解决方案

使用 Grounding DINO 进行目标检测,针对鞋子,帽子,裤子等。然后直接使用代码剪切出对应的长方形块(这里保留了长方形块里面的所有元素,包括背景。最后没有使用 SAM 切割出不规则的物品,原因是我们发现,仅对目标图片做方框的截取即可,使用 SAM 做像素级分割,反而降低了模型的效果 )。

图4.png

首先我们先构建模型压缩包,并上传至 S3 存储桶中,如下图所示:

import boto3
import sagemaker
from sagemaker import serializers, deserializers
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
s3_model_prefix = "east-ai-models/grounded-sam"
!touch dummy
!rm -f model.tar.gz
!tar czvf model.tar.gz dummy
s3_model_artifact = sess.upload_data("model.tar.gz", bucket, s3_model_prefix)print(f"S3 Code or Model tar uploaded to --- > {s3_model_artifact}")
!rm -f dummy

接下来我们准备创建模型所需要的代码,以下代码均在本地“code”路径下:

endpoint_name ="grounded-sam"#%%
framework_version = '2.3.0'
py_version = 'py311'
instance_type = "ml.g4dn.xlarge"
endpoint_name ="grounded-sam"
model = PyTorchModel(
    model_data = s3_model_artifact,
    entry_point = 'inference.py',
    source_dir = "./code/",
    role = role,
    framework_version = framework_version,    py_version = py_version,)print("模型部署过程大约需要 7~8 分钟,请等待" + "."*20)
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,)print("模型部署已完成,可以继续执行后续步骤" + "."*20)

准备自定义推理脚本 clip_inference.py。我们在 model_fn 中进行模型加载,在 predict_fn 定义推理逻辑,核心代码如下:

import os
import io
from PIL import Image
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
import json
import boto3
import uuid
import math

def get_detection_boxes(image_source: Image, model: dict, prompt: str = "clothes . pants . hats . shoes") -> (list, list, list):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    box_treshold = 0.3
    text_treshold = 0.25
    transform = T.Compose([
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
    image_transformed, _ = transform(image_source, None)try:
        boxes, logits, phrases = predict(
            model=model['dino'],
            image=image_transformed,
            caption=prompt,
            box_threshold=box_treshold,
            text_threshold=text_treshold,
            device='cuda')except Exception as e:print(e)return
    boxes_list = boxes.numpy().tolist()
    logits_list = logits.numpy().tolist()return boxes_list, logits_list, phrases

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.cuda()
    _ = model.eval()return model

def model_fn(model_dir):
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filenmae = "groundingdino_swint_ogc.pth"
    ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
    model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
    model_dic = {'dino': model, 'sam': ''}return model_dic

def save_file_to_s3(mask_image, file_extension, output_mask_image_dir: str):# 图片存储到s3......return mask_image_output

def crop_images_from_boxes(image_source: Image, boxes: list, scale_factor: float = 1.0, target_size: int = 400) -> list:
    cropped_images = []
    width, height = image_source.size
for box in boxes:
        cx, cy, w, h = [coord * scale_factor for coord in box]# 计算边界框的左上角和右下角坐标
        x1 = max(0, math.floor((cx - w / 2) * width))
        y1 = max(0, math.floor((cy - h / 2) * height))
        x2 = min(width, math.ceil((cx + w / 2) * width))
        y2 = min(height, math.ceil((cy + h / 2) * height))# 如果边界框在图像范围内,则裁剪图像if x2 > x1 and y2 > y1:
            cropped_image = image_source.crop((x1, y1, x2, y2))# 调整裁剪后图像的大小
            cropped_width, cropped_height = cropped_image.size# 等比例调整到目标尺寸
            scale = min(target_size / cropped_width, target_size / cropped_height)
            new_width = int(cropped_width * scale)
            new_height = int(cropped_height * scale)
            cropped_image = cropped_image.resize((new_width, new_height), resample=Image.BICUBIC)
            cropped_images.append(cropped_image)return cropped_images

def predict_fn(input_data, model):print("=================Dino detect start=================")try:
        file_extension = os.path.splitext(input_data['input_image'])[1][1:].lower()
        dir_lst = input_data['input_image'].split('/')
        s3_client = boto3.client('s3')
        s3_response_object = s3_client.get_object(Bucket=dir_lst[2], Key='/'.join(dir_lst[3:]))
        img_bytes = s3_response_object['Body'].read()
        image_source = Image.open(io.BytesIO(img_bytes)).convert("RGB")if 'boxes' not in input_data:
            prompt = input_data['prompt']
            boxes, logits, phrases = get_detection_boxes(image_source, model, prompt)if len(boxes) == 0:return {"error_message": "The image does not contain any object needed"}elif len(boxes) > 1:return {"boxes": boxes, "file_type": file_extension, "logits": logits, "phrases": phrases}
        boxes = [input_data['boxes']] if 'boxes' in input_data else boxes
        cropped_images = crop_images_from_boxes(image_source, boxes)
        mask_image_output = save_file_to_s3(cropped_images[0], file_extension, input_data['output_mask_image_dir'])return {"mask_image_output": mask_image_output}except Exception as e:print(e)

3. Embedding 模型

3.1 技术难点

传统的图片 Embedding 模型在用作向量召回时往往存在如下问题

(1)缺乏标注的图片:训练模型需要大量已标注的图像数据,但获取这些标注成本过高,可能存在困难。

(2)模型需要高精度以进行细粒度比较:为了准确匹配相似产品,嵌入模型需要具有足够的精度来捕捉细微的差异。

(3)模型输出的 Embedding 的鲁棒性不足:会受到背景,衣物形变,拍摄角度,光线等因素的较大影响。

(4)需要私有部署选项以保证安全和隐私:出于安全和隐私考虑,可能需要在本地私有环境中部署模型。

(5)模型应该可定制和可扩展:为了满足不同的需求,模型应该具有一定的定制和扩展能力。

3.2 解决方案

先用基于 DINO+VIT 的模型在私有产品图片数据上进行预训练,这个阶段无需进行标注,DINO 就可以自行关注到图片中的主体,而不容易受到背景的干扰。在第二阶段,我们采用对比学习或者分类的方式对模型进行 Finetune 从而进一步提升召回能力。下图可视化了 DINO 模型的注意力层,展示其相对于传统模型的优点,我们可以看到 DINO 这一列中展示的模型注意力可以剥离背景的干扰因素,而传统的有监督算法的注意力没有准确的捕捉到图片中的主体。

图5.png

在具体的算法开发过程中,我们评估了 DINO 和 DINO V2,Triplet Loss 和 Cross Entropy Loss,也对比了 VIT 和 CNN,在大量实验的基础上,得到的最终的结论如下:

(1)Triplet loss,目前看下来经济性远不如 cross entropy loss,同样的训练轮次完全不收敛(个位数的 mAP),原因是 cross entropy loss 训练过程中一次梯度更新优化的是整个样本分布,而 triplet loss 一次梯度更新仅仅是优化采样到的正负样本,训练效率完全不是一个等级,但是 triplet loss 这种直接优化特征的模式其实更加适配向量匹配任务,可能需要更大的 batch size 或者更细致的超参数调节,加上更完备的难负样本挖掘。

(2)DINOv2(即加入了 MAE 损失的 DINO)在此场景下毫无意外地比 DINO 差,甚至 large 和 giant 版本的 VIT-dinov2 都比不过 Base 的 VIT-dino,目前的猜测是由于重建类的损失(MAE 损失)并不适配判别场景,此种场景下还是判别损失(Cross Entorpy Loss)更加合适,关注的特征也更加低频,提取到的特征更加适合做判别任务。

(3)DINOv1 是目前最适合做向量搜索的预训练算法,这种预训练方法甚至可以一定程度上弥补模型参数量的差距。

(4)有条件的话可以用 DINO 的训练框架预训练更大的 VIT 模型,因为目前 DINO 官方给出的最大的 VIT-dino 只到 base,并没有 large 版本放出,后续可以在 Google Landmarks v2+ImageNet+私有数据集上进行预训练。

将训练好的 DINO 模型部署在 SageMaker 上,需要提供推理脚本文件 inference.py。其中的主要代码如下:

...def predict_fn(single_data, model):"""
    Predict a result using a single data
    :param single_data: a single numpy array for an image
    :type single_data: numpy.array
    :param model: the loaded model
    :type model:
    :return:an object with prediction value
    :rtype: object
    """
    imsize = 648
    transform = pth_transforms.Compose([
            pth_transforms.Resize((imsize, imsize), interpolation=3),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])
    image = transform(single_data)try:
        output = model(image[None].cuda())# First, move the tensor to CPU
        cpu_tensor = output.cpu()# Then convert to NumPy array
        numpy_array = cpu_tensor.detach().numpy()return numpy_array[0]except Exception as e:raise e

def input_fn(input_data, request_content_type):#  The request_body is coming 1 by 1"""An input_fn that loads a pickled tensor"""if request_content_type == "application/json":try:
            json_request = json.loads(input_data)
            file_byte_string = s3_client.get_object(
                Bucket=json_request["bucket"], Key=json_request["file_name"])["Body"].read()
            im = Image.open(io.BytesIO(file_byte_string))
            im = im.convert("RGB")return imexcept Exception as e:raise eelif request_content_type == "application/x-image":
        im = Image.open(BytesIO(input_data))
        im = im.convert("RGB")return imelse:# Handle other content-types here or raise an Exception# if the content type is not supported.raise Exception("Unsupported content type")def model_fn(model_dir):
    pretrained_weights = os.path.join(model_dir, "checkpoint.pth")print(os.path.abspath(os.path.join(model_dir, "config.json")))# Open the file and load its contents
    config_path = os.path.join(model_dir, "config.json")with open(config_path, "r") as config_file:
        model_config = json.load(config_file)print("loading model info: %s", model_config)# load pretrained weightsif os.path.isfile(pretrained_weights):
        model = vits.dict[model_config["arch"]](
            patch_size=model_config["patch_size"],
            drop_path_rate=model_config["drop_path_rate"],  # stochastic depth)
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)print("Pretrained weights found at {} and loaded with msg: {}".format(
                pretrained_weights, msg))else:print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.")
        model = torch.hub.load("facebookresearch/xcit:main", "vit_small", pretrained=False)
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
    model = model.cuda()
    model.eval()return model...

4. 向量搜索

4.1 技术难点****

(1)用于产品召回,而非图像召回:最终目标是根据图像找到相应的产品,而不是简单地找到相似图像。

(2)需要支持从向量存储中高效检索向量:向量数据库需要能够支撑百万级的快速向量检索,且搜索结果应该能够提供产品的唯一标识符(如产品代码)。

4.2 解决方案

使用 OpenSearch 同时存储图片的向量数据和产品的代码,这样在做向量相似度对比后,可以同时获取产品代码。同时使用 Faiss-HNSW 算法作为检索算法,同时相似度的计算我们使用了和模型 Finetune 阶段相匹配的 Cosine 函数。核心的考虑点如下图:

图6.png

OpenSearch 提供了多种算法选择,通过下图的对比,我们最终选择了 FAISS-HNSW 作为向量索引算法。

图7.png

总的来说,这里涉及图像处理、目标检测、图像分割、embedding 和向量搜索等多个方面,需要解决数据、模型精度、部署环境和搜索结果等多个挑战。通过合理的数据预处理、模型选择和系统设计,可以构建一个高效的基于图像的产品检索系统。

五、实验测试结果

图8.png

上图是 CMC(Cumulative Match Characteristic)的测试结果,横坐标 rank n 代表检索出的前 n 个产品,纵坐标是检索出的前 n 个产品里面有目标产品的概率。我们的测试产品库中包含 6000 个左右的品类,用户图片都是真实世界场景的图片,可以看到有 75% 的图片在 rank 1 的位置召回,86% 的正确产品图片都在前 5 的位置被召回。这个检索的精度,满足了客户要求的前 5 个产品里面有目标产品的概率达到 85% 的要求。并且经过业务人员的确认,搜索可以自动忽略背景的影响,对于细节的区别和辨认也已经接近或者达到人类水平。

六、结论

本文通过使用服装鞋类商品进行模型训练,同时通过 GroundingDINO 进行目标物品检测和剪切的方式对图片进行搜索,这种方式满足企业级的,特别是垂直行业的高精度搜索。有助于更好地提升用户的搜索体验。该方案也可以拓展到其他的垂直行业使用,如电商、游戏、短视频,医疗、制造业等。

参考资料

(1)论文《Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection》,作者:Shilong Liu, Zhaoyang Zeng, Tianhe Ren, Feng Li, Hao Zhang, Jie Yang, Qing Jiang, Chunyuan Li, Jianwei Yang, Hang Su, Jun Zhu, Lei Zhang。链接:https://arxiv.org/abs/2303.05499

(2)DINO-X 开放平台最新的 DINO API :https://cloud.deepdataspace.com/

(3)Grounding DINO Playground:https://cloud.deepdataspace.com/playground/grounding_dino