All downloads are free. Search and download functionality uses the official Maven repository.

io.milvus.orm.iterator.IteratorAdapterV2 Maven / Gradle / Ivy

There is a newer version: 2.4.8
Show newest version
package io.milvus.orm.iterator;

import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.exception.ParamException;
import io.milvus.grpc.DataType;
import io.milvus.grpc.PlaceholderType;
import io.milvus.param.MetricType;
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.SearchIteratorParam;
import io.milvus.param.dml.QueryIteratorParam;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.QueryIteratorReq;
import io.milvus.v2.service.vector.request.SearchIteratorReq;
import io.milvus.v2.service.vector.request.data.BaseVector;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;

public class IteratorAdapterV2 {
    public static QueryIteratorParam convertV2Req(QueryIteratorReq queryIteratorReq) {
        return QueryIteratorParam.newBuilder()
                .withDatabaseName(queryIteratorReq.getDatabaseName())
                .withCollectionName(queryIteratorReq.getCollectionName())
                .withPartitionNames(queryIteratorReq.getPartitionNames())
                .withExpr(queryIteratorReq.getExpr())
                .withOutFields(queryIteratorReq.getOutputFields())
                .withConsistencyLevel(ConsistencyLevelEnum.valueOf(queryIteratorReq.getConsistencyLevel().name()))
                .withOffset(queryIteratorReq.getOffset())
                .withLimit(queryIteratorReq.getLimit())
                .withIgnoreGrowing(queryIteratorReq.isIgnoreGrowing())
                .withBatchSize(queryIteratorReq.getBatchSize())
                .build();
    }
    public static SearchIteratorParam convertV2Req(SearchIteratorReq searchIteratorReq) {
        MetricType metricType = MetricType.None;
        if (searchIteratorReq.getMetricType() != IndexParam.MetricType.INVALID) {
            metricType = MetricType.valueOf(searchIteratorReq.getMetricType().name());
        }

        SearchIteratorParam.Builder builder = SearchIteratorParam.newBuilder()
                .withDatabaseName(searchIteratorReq.getDatabaseName())
                .withCollectionName(searchIteratorReq.getCollectionName())
                .withPartitionNames(searchIteratorReq.getPartitionNames())
                .withVectorFieldName(searchIteratorReq.getVectorFieldName())
                .withMetricType(metricType)
                .withTopK(searchIteratorReq.getTopK())
                .withExpr(searchIteratorReq.getExpr())
                .withOutFields(searchIteratorReq.getOutputFields())
                .withRoundDecimal(searchIteratorReq.getRoundDecimal())
                .withParams(searchIteratorReq.getParams())
                .withIgnoreGrowing(searchIteratorReq.isIgnoreGrowing())
                .withBatchSize(searchIteratorReq.getBatchSize());

        if (searchIteratorReq.getConsistencyLevel() != null) {
            builder.withConsistencyLevel(ConsistencyLevelEnum.valueOf(searchIteratorReq.getConsistencyLevel().name()));
        }

        List vectors = searchIteratorReq.getVectors();
        PlaceholderType plType = vectors.get(0).getPlaceholderType();
        for (BaseVector vector : vectors) {
            if (vector.getPlaceholderType() != plType) {
                throw new ParamException("Different types of target vectors in a search request is not allowed.");
            }
        }

        switch (plType) {
            case FloatVector: {
                List> data = new ArrayList<>();
                vectors.forEach(vector->data.add((List)vector.getData()));
                builder.withVectors(data);
                break;
            }
            case BinaryVector: {
                List data = new ArrayList<>();
                vectors.forEach(vector->data.add((ByteBuffer)vector.getData()));
                builder.withVectors(data);
                break;
            }
            default:
                throw new ParamException("Unsupported vector type.");
        }

        return builder.build();
    }

    public static FieldType convertV2Field(CreateCollectionReq.FieldSchema schema) {
        FieldType.Builder builder = FieldType.newBuilder()
                .withName(schema.getName())
                .withDataType(DataType.valueOf(schema.getDataType().name()))
                .withPrimaryKey(schema.getIsPrimaryKey())
                .withAutoID(schema.getAutoID())
                .withPartitionKey(schema.getIsPartitionKey());

        if (schema.getDimension() != null) {
            builder.withDimension(schema.getDimension());
        }
        if (schema.getMaxLength() != null) {
            builder.withMaxLength(schema.getMaxLength());
        }
        if (schema.getMaxCapacity() != null) {
            builder.withMaxCapacity(schema.getMaxLength());
        }
        if (schema.getElementType() != null) {
            builder.withElementType(DataType.valueOf(schema.getElementType().name()));
        }
        return builder.build();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy