以图搜图实现

前言

现在很多电商平台都实现了以图搜图功能,我们的项目也需要这个功能,在此记录一下实现过程。为了免费实现以图搜图,这里参考开源库 AIAS,使用 DJL 提取图片特征向量,再借助 Elasticsearch 的向量检索完成相似图片搜索。话不多说,直接上实现流程。

配置 Maven 依赖(注意:Elasticsearch 客户端相关依赖的版本应保持一致,最好与服务端版本相同;DJL 相关依赖也应使用同一版本)

<!-- Elasticsearch Java client: keep both client artifacts on the same version (ideally the
     version of the Elasticsearch server) to avoid incompatibilities between them. -->
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-high-level-client</artifactId>
    <version>7.14.0</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
    <version>7.14.0</version>
</dependency>
<!-- NOTE(review): no <version> is given; this only resolves if a parent POM or
     <dependencyManagement> section supplies one. -->
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>elasticsearch-plugin-hnswlib</artifactId>
</dependency>
<!-- DJL (Deep Java Library): keep all DJL artifacts on one version. -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.17.0</version>
</dependency>
<!-- Pytorch -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.17.0</version>
</dependency>

编写特征提取工具类(其中 ImageEncoderModel 负责提供模型加载条件 Criteria,本文未贴出其实现)

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import lombok.extern.slf4j.Slf4j;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;

@Slf4j
public class ImageUtil {

    public static String filePath = "D:/xxxx/code_workspace/xxxx";

    public static float[] featureExtraction(String url){
        float[] feature = null;
        try {
//            Path imageFile = Paths.get("src/test/resources/car1.png");
            Path imageFile = Paths.get(filePath+url);
            Image img = ImageFactory.getInstance().fromFile(imageFile);
            Criteria<Image, float[]> criteria = new ImageEncoderModel().criteria();
            ZooModel model = ModelZoo.loadModel(criteria);
            Predictor<Image, float[]> predictor = model.newPredictor();
            feature = predictor.predict(img);
            System.out.println(feature.length);
            log.info(Arrays.toString(feature));
        }catch (Exception e){
            e.printStackTrace();
        }
        return feature;
    }

}

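创建ES向量索引(注意:index.codec 为 proxima、字段类型为 proxima_vector 的写法需要阿里云 Elasticsearch 的 aliyun-knn 向量检索插件,开源版 Elasticsearch 不支持;dim 必须与模型输出的特征向量长度一致)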

PUT test_img
{
  "settings": {
    "index.number_of_shards": 3,
    "index.number_of_replicas": 1,
    "index.codec": "proxima",
    "index.vector.algorithm": "hnsw"
  },
  "mappings": {
    "properties": {
      "goods_id": {
        "type": "long"
      },
      "feature": {
        "type": "proxima_vector",
        "dim": 512,
        "vector_type": "float"
      }
    }
  }
}

定义ES实体类

import lombok.Data;

/**
 * Document written to the image-feature Elasticsearch index, one per goods item.
 * Lombok {@code @Data} generates the getters/setters used when the object is converted to JSON.
 * The snake_case field names are deliberate: they presumably become the JSON property names,
 * so they must match the field names in the index mapping — do not rename them.
 */
@Data
public class Es_GoodsFeature {
    // Goods id; the test_img mapping declares it as "long".
    Long goods_id;
    // Image feature vector; the test_img mapping declares a 512-dim float vector for it.
    float[] feature;
    // Goods name; not declared in the test_img mapping, so presumably mapped dynamically by ES — TODO confirm.
    String goods_name;
}

提取图片特征同步到ES

		// Extract the feature vector of the goods cover image.
		String coverImg = goods.getCover_img();
		float[] feature = ImageUtil.featureExtraction(coverImg);
		// featureExtraction returns null when the image cannot be read or inference fails: skip this item.
		if (feature == null) {
			return;
		}
		Es_GoodsFeature esGoodsFeature = new Es_GoodsFeature();
		esGoodsFeature.setGoods_id(goods.getGoods_id());
		// Store the vector that was just extracted, not whatever is currently on the goods record.
		esGoodsFeature.setFeature(feature);
		esGoodsFeature.setGoods_name(goods.getGoods_name());
		// Index the document under its goods_id. Mapping types are deprecated in 7.x (the only
		// allowed type is _doc), so use the typeless IndexRequest(index).id(...) form.
		IndexRequest indexRequest = new IndexRequest(esFeatureConfig.getIndex())
				.id(esGoodsFeature.getGoods_id().toString())
				.source(JSON.toJSONString(esGoodsFeature), XContentType.JSON);
		// Synchronous call with the default request options.
		IndexResponse indexResponse = esFeatureClient.index(indexRequest, RequestOptions.DEFAULT);

搜索图片

import com.utils.ImageUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.Arrays;
@Slf4j
public class TestImage {
    public static String host = "127.0.0.1";
    public static int port = 9200;
    public static String username = "test_feature";
    // NOTE(review): credentials are hard-coded; prefer reading them from the environment or config.
    public static String password = "xxxx";
    public static String index = "test_img";

    /** Number of nearest neighbours requested by the hnsw query. */
    private static final int TOP_K = 10;

    /**
     * End-to-end check: extracts the feature vector of a sample image, runs an hnsw vector query
     * against {@link #index} and prints the matched goods with their scores.
     *
     * @throws IOException if the search request fails or the client cannot be closed
     * @throws IllegalStateException if the feature vector of the sample image cannot be extracted
     */
    @Test
    public void TestSearch() throws IOException {
        String imgUrl = "/img/20230411/search.png";
        float[] feature = ImageUtil.featureExtraction(imgUrl);
        // featureExtraction returns null on failure; fail loudly instead of querying with "null".
        if (feature == null) {
            throw new IllegalStateException("Feature extraction failed for " + imgUrl);
        }
        log.info(Arrays.toString(feature));

        // try-with-resources closes the underlying HTTP client; exceptions propagate so the test fails.
        try (RestHighLevelClient client = new RestHighLevelClient(buildClientBuilder())) {
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                    .query(QueryBuilders.wrapperQuery(buildHnswQuery(feature, TOP_K)))
                    // Only return the fields we print (optionally add .minScore(0.9f) to drop weak matches).
                    .fetchSource(new String[]{"goods_id", "goods_name"}, null);
            log.info(searchSourceBuilder.toString());
            SearchRequest searchRequest = new SearchRequest(index).source(searchSourceBuilder);
            SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
            for (SearchHit hit : searchResponse.getHits()) {
                String id = hit.getId();
                double score = hit.getScore();
                System.out.println("id: " + id + " 商品名:" + hit.getSourceAsMap().get("goods_name") + " 相似度:" + score);
            }
        }
    }

    /**
     * Creates a REST client builder for {@link #host}:{@link #port} that authenticates with
     * basic auth using {@link #username}/{@link #password}.
     */
    private static RestClientBuilder buildClientBuilder() {
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        // The user name and password set when the Elasticsearch instance was created
        // (also the Kibana login credentials).
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(username, password));
        return RestClient.builder(new HttpHost(host, port))
                .setHttpClientConfigCallback(httpClientBuilder ->
                        httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider));
    }

    /**
     * Builds the JSON body of an hnsw query on the "feature" field.
     *
     * @param feature query vector
     * @param size    number of nearest neighbours to return
     * @return the query JSON, suitable for {@code QueryBuilders.wrapperQuery}
     */
    private static String buildHnswQuery(float[] feature, int size) {
        return "{\n" +
                "    \"hnsw\": {  \n" +
                "      \"feature\": {\n" +
                "        \"vector\": " + Arrays.toString(feature) + ", \n" +
                "        \"size\": " + size + " \n" +
                "      }\n" +
                "    }\n" +
                "  }";
    }
}

总结

以图搜图没那么难,找对方式就可以了。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值