前言
现在很多电商平台都实现了以图搜图功能,我们的项目也有这个需求,在此记录一下实现过程。为了免费实现以图搜图,这里使用开源库 AIAS(基于 DJL 深度学习框架)提取图片特征向量,再借助 Elasticsearch 的向量检索查找相似图片。话不多说,直接上实现流程。
配置依赖
<!-- Elasticsearch Java REST clients.
     Keep elasticsearch-rest-client at the SAME version as elasticsearch-rest-high-level-client:
     the high-level client is built against its matching low-level client, and mixing
     7.14.0 with 7.15.0 risks runtime incompatibilities. -->
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-high-level-client</artifactId>
    <version>7.14.0</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-client</artifactId>
    <version>7.14.0</version>
</dependency>
<!-- NOTE(review): no <version> is declared here, so it must be supplied elsewhere
     (parent POM / dependencyManagement) or Maven will reject the build - confirm. -->
<dependency>
    <groupId>org.elasticsearch.plugin</groupId>
    <artifactId>elasticsearch-plugin-hnswlib</artifactId>
</dependency>
<!-- DJL (Deep Java Library) core artifacts; all DJL modules should share one version. -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.17.0</version>
</dependency>
<!-- Pytorch -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.17.0</version>
</dependency>
编写图片特征提取工具类
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import lombok.extern.slf4j.Slf4j;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
@Slf4j
public class ImageUtil {
public static String filePath = "D:/xxxx/code_workspace/xxxx";
public static float[] featureExtraction(String url){
float[] feature = null;
try {
// Path imageFile = Paths.get("src/test/resources/car1.png");
Path imageFile = Paths.get(filePath+url);
Image img = ImageFactory.getInstance().fromFile(imageFile);
Criteria<Image, float[]> criteria = new ImageEncoderModel().criteria();
ZooModel model = ModelZoo.loadModel(criteria);
Predictor<Image, float[]> predictor = model.newPredictor();
feature = predictor.predict(img);
System.out.println(feature.length);
log.info(Arrays.toString(feature));
}catch (Exception e){
e.printStackTrace();
}
return feature;
}
}
创建ES向量索引
# Create the vector index that stores one image-feature document per goods item.
# NOTE(review): "index.codec": "proxima", "index.vector.algorithm" and the "proxima_vector"
# field type look like vendor features (Alibaba Cloud Elasticsearch vector plugin), not stock
# Elasticsearch - confirm they are available on the target cluster.
# "dim" must equal the length of the float[] produced by the feature-extraction model
# (512 assumed here - verify against the encoder actually used).
# goods_name is not mapped explicitly; it will presumably be mapped dynamically on first insert.
PUT test_img
{
"settings": {
"index.codec": "proxima",
"index.vector.algorithm": "hnsw",
"index.number_of_replicas":1,
"index.number_of_shards":3
},
"mappings": {
"properties": {
"feature": {
"type": "proxima_vector",
"dim": 512,
"vector_type": "float"
},
"goods_id":{
"type": "long"
}
}
}
}
定义ES实体类
import lombok.Data;
/**
 * Elasticsearch document holding one goods item's image feature.
 *
 * <p>Field names use snake_case, matching the field names in the index mapping
 * ({@code goods_id}, {@code feature}). Lombok {@code @Data} generates accessors such as
 * {@code getGoods_id()} / {@code setGoods_id(Long)}.
 */
@Data
public class Es_GoodsFeature {
    /** Goods id; the indexing code uses the goods id as the ES document id. */
    Long goods_id;
    /** Image feature vector; its length must match the mapping's "dim" for the feature field. */
    float[] feature;
    /** Goods name, fetched from _source in search results for display. */
    String goods_name;
}
提取图片特征同步到ES
// Extract the cover image's feature vector and index it into Elasticsearch under the goods id.
String coverImg = goods.getCover_img();
float[] feature = ImageUtil.featureExtraction(coverImg);
// featureExtraction returns null when the image cannot be read or inference fails; skip this goods item.
if (feature == null) {
    return;
}
Es_GoodsFeature esGoodsFeature = new Es_GoodsFeature();
esGoodsFeature.setGoods_id(goods.getGoods_id());
// Store the vector just extracted from the cover image.
esGoodsFeature.setFeature(feature);
esGoodsFeature.setGoods_name(goods.getGoods_name());
// ES 7.x has a single implicit "_doc" type, so use the typeless constructor plus id(...)
// instead of the deprecated IndexRequest(index, type, id). The document id is the goods id.
IndexRequest indexRequest = new IndexRequest(esFeatureConfig.getIndex())
        .id(esGoodsFeature.getGoods_id().toString())
        .source(JSON.toJSONString(esGoodsFeature), XContentType.JSON);
// Synchronous call with the default RequestOptions.
IndexResponse indexResponse = esFeatureClient.index(indexRequest, RequestOptions.DEFAULT);
搜索图片
import com.utils.ImageUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Arrays;
/**
 * Manual integration test: searches the vector index for goods whose image resembles a query image.
 */
@Slf4j
public class TestImage {
    public static String host = "127.0.0.1";
    public static int port = 9200;
    public static String username = "test_feature";
    // NOTE(review): avoid committing real credentials; read them from the environment or config instead.
    public static String password = "xxxx";
    public static String index = "test_img";

    /**
     * Runs an {@code hnsw} k-NN search against {@link #index} and prints each hit's details.
     *
     * <p>The query image's feature vector is extracted with
     * {@link ImageUtil#featureExtraction(String)}. For each hit, the test prints the id, goods name
     * and score. Failures propagate so JUnit reports the test as failed instead of silently passing.
     *
     * @throws IOException if the search request fails or the client cannot be closed
     */
    @Test
    public void TestSearch() throws IOException {
        String imgUrl = "/img/20230411/search.png";
        float[] feature = ImageUtil.featureExtraction(imgUrl);
        // featureExtraction returns null on failure; Arrays.toString(null) would embed the literal
        // "null" into the query and produce an invalid request.
        if (feature == null) {
            throw new IllegalStateException("Feature extraction failed for " + imgUrl);
        }

        // Username/password are the ones set when the Alibaba Cloud Elasticsearch instance was
        // created (also used to log in to the Kibana console).
        final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(username, password));
        // host/port: the cluster's public endpoint, shown on the instance's basic-information page.
        RestClientBuilder restClientBuilder = RestClient.builder(new HttpHost(host, port))
                .setHttpClientConfigCallback(
                        httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider));

        // The client owns HTTP connection pools and I/O threads; try-with-resources releases them.
        try (RestHighLevelClient client = new RestHighLevelClient(restClientBuilder)) {
            // k-NN query on the "feature" vector field; "size" presumably caps the neighbours returned.
            String hnswQuery = "{\n" +
                    "    \"hnsw\": { \n" +
                    "      \"feature\": {\n" +
                    "        \"vector\": " + Arrays.toString(feature) + ", \n" +
                    "        \"size\": 10 \n" +
                    "      }\n" +
                    "    }\n" +
                    "  }";
            QueryBuilder queryBuilder = QueryBuilders.wrapperQuery(hnswQuery);
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                    .query(queryBuilder)
                    // Only return the fields needed for display.
                    .fetchSource(new String[]{"goods_id", "goods_name"}, null);
            // To drop weak matches, add e.g. searchSourceBuilder.minScore(0.9f).
            SearchRequest searchRequest = new SearchRequest(index).source(searchSourceBuilder);
            log.debug("Search source: {}", searchSourceBuilder);

            SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
            SearchHits hits = searchResponse.getHits();
            for (SearchHit hit : hits) {
                System.out.println("id: " + hit.getId() + " 商品名:" + hit.getSourceAsMap().get("goods_name") + " 相似度:" + hit.getScore());
            }
        }
    }
}
总结
以图搜图并没有那么难,找对方法就可以了:用同一个模型把商品图片和查询图片都转换成特征向量,把商品向量写入 Elasticsearch 的向量字段,搜索时用查询图片的向量做 HNSW 近邻检索,按相似度返回商品即可。