Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
io.milvus.param.ParamUtils Maven / Gradle / Ivy
Go to download
Java SDK for Milvus, a distributed high-performance vector database.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.milvus.param;
import com.google.gson.*;
import com.google.gson.reflect.TypeToken;
import com.google.protobuf.ByteString;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.common.utils.JacksonUtils;
import io.milvus.exception.ParamException;
import io.milvus.grpc.*;
import io.milvus.param.collection.FieldType;
import io.milvus.param.dml.*;
import io.milvus.response.DescCollResponseWrapper;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.stream.Collectors;
/**
* Utility functions for param classes
*/
public class ParamUtils {
// Shared Gson instance used by this class to convert JSON elements of row-based input into Java values.
private static final Gson GSON_INSTANCE = new Gson();
/**
 * Builds the type-mismatch message templates used when validating column-based insert data.
 * Each template contains a single {@code %s} placeholder for the field name.
 *
 * @return a new map from data type to message template
 */
private static HashMap<DataType, String> getTypeErrorMsgForColumnInsert() {
    final HashMap<DataType, String> typeErrMsg = new HashMap<>();
    typeErrMsg.put(DataType.None, "Type mismatch for field '%s': the field type is illegal.");
    typeErrMsg.put(DataType.Bool, "Type mismatch for field '%s': Bool field value type must be Boolean.");
    typeErrMsg.put(DataType.Int8, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer.");
    typeErrMsg.put(DataType.Int16, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer.");
    typeErrMsg.put(DataType.Int32, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer.");
    typeErrMsg.put(DataType.Int64, "Type mismatch for field '%s': Int64 field value type must be Long.");
    typeErrMsg.put(DataType.Float, "Type mismatch for field '%s': Float field value type must be Float.");
    typeErrMsg.put(DataType.Double, "Type mismatch for field '%s': Double field value type must be Double.");
    typeErrMsg.put(DataType.String, "Type mismatch for field '%s': String field value type must be String."); // actually String type is useless
    typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String, and the string length must shorter than max_length.");
    // The JSON branch of checkFieldData() looks up a template for DataType.JSON; without this
    // entry the lookup returns null and String.format() fails with a NullPointerException.
    typeErrMsg.put(DataType.JSON, "Type mismatch for field '%s': JSON field value type must be JsonElement.");
    typeErrMsg.put(DataType.Array, "Type mismatch for field '%s': Array field value type must be List, each object type must be element_type, and the array length must be shorter than max_capacity.");
    typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List.");
    typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer.");
    return typeErrMsg;
}
/**
 * Builds the type-mismatch message templates used when validating row-based (JSON) insert data.
 * Each template contains a single {@code %s} placeholder for the field name.
 *
 * @return a new map from data type to message template
 */
private static HashMap<DataType, String> getTypeErrorMsgForRowInsert() {
    final HashMap<DataType, String> typeErrMsg = new HashMap<>();
    typeErrMsg.put(DataType.None, "Type mismatch for field '%s': the field type is illegal.");
    typeErrMsg.put(DataType.Bool, "Type mismatch for field '%s': Bool field value type must be JsonPrimitive.");
    // Int8 is an integer type like Int16/Int32, so it takes the same "of number" wording.
    typeErrMsg.put(DataType.Int8, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.Int16, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.Int32, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.Int64, "Type mismatch for field '%s': Int64 field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.Float, "Type mismatch for field '%s': Float field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.Double, "Type mismatch for field '%s': Double field value type must be JsonPrimitive of number.");
    typeErrMsg.put(DataType.String, "Type mismatch for field '%s': String field value type must be JsonPrimitive of string."); // actually String type is useless
    typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be JsonPrimitive of string, and the string length must shorter than max_length.");
    typeErrMsg.put(DataType.Array, "Type mismatch for field '%s': Array field value type must be JsonArray, each object type must be element_type, and the array length must be shorter than max_capacity.");
    typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be JsonArray of List.");
    typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be JsonArray of byte[].");
    return typeErrMsg;
}
/**
 * Validates the values of a column-based insert field against the field schema.
 * Delegates to the list-based overload with element-type verification disabled.
 *
 * @param fieldSchema schema of the target field
 * @param fieldData the column whose values are verified
 * @throws ParamException if any value does not match the schema
 */
public static void checkFieldData(FieldType fieldSchema, InsertParam.Field fieldData) {
    List<?> values = fieldData.getValues();
    checkFieldData(fieldSchema, values, false);
}
/**
 * Converts the byte count of a binary vector into its dimension.
 *
 * @param dataType must be {@code DataType.BinaryVector}
 * @param byteCount number of bytes in the vector
 * @return {@code byteCount * 8}
 * @throws ParamException if {@code dataType} is not {@code BinaryVector}
 */
private static int calculateBinVectorDim(DataType dataType, int byteCount) {
    if (dataType != DataType.BinaryVector) {
        throw new ParamException(String.format("%s is not binary vector type", dataType.name()));
    }
    // for BinaryVector, each byte is 8 dimensions
    return byteCount * 8;
}
public static void checkFieldData(FieldType fieldSchema, List> values, boolean verifyElementType) {
HashMap errMsgs = getTypeErrorMsgForColumnInsert();
DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType();
if (verifyElementType && values.size() > fieldSchema.getMaxCapacity()) {
throw new ParamException(String.format("Array field '%s' length: %d exceeds max capacity: %d",
fieldSchema.getName(), values.size(), fieldSchema.getMaxCapacity()));
}
switch (dataType) {
case FloatVector: {
int dim = fieldSchema.getDimension();
for (int i = 0; i < values.size(); ++i) {
// is List<> ?
Object value = values.get(i);
if (!(value instanceof List)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
// is List ?
List> temp = (List>)value;
for (Object v : temp) {
if (!(v instanceof Float)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
// check dimension
if (temp.size() != dim) {
String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
}
}
break;
}
case BinaryVector: {
int dim = fieldSchema.getDimension();
for (int i = 0; i < values.size(); ++i) {
Object value = values.get(i);
// is ByteBuffer?
if (!(value instanceof ByteBuffer)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
// check dimension
ByteBuffer v = (ByteBuffer)value;
int real_dim = calculateBinVectorDim(dataType, v.position());
if (real_dim != dim) {
String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d";
throw new ParamException(String.format(msg, fieldSchema.getName(), i, real_dim, dim));
}
}
break;
}
case Int64:
for (Object value : values) {
if (!(value instanceof Long)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case Int32:
case Int16:
case Int8:
for (Object value : values) {
if (!(value instanceof Short) && !(value instanceof Integer)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case Bool:
for (Object value : values) {
if (!(value instanceof Boolean)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case Float:
for (Object value : values) {
if (!(value instanceof Float)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case Double:
for (Object value : values) {
if (!(value instanceof Double)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case VarChar:
case String:
for (Object value : values) {
if (!(value instanceof String)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
if (((String) value).length() > fieldSchema.getMaxLength()) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case JSON:
for (Object value : values) {
if (!(value instanceof JsonElement)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
}
break;
case Array:
for (Object value : values) {
if (!(value instanceof List)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
List> array = (List>)value;
if (array.size() > fieldSchema.getMaxCapacity()) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
checkFieldData(fieldSchema, array, true);
}
break;
default:
throw new ParamException("Unsupported data type returned by FieldData");
}
}
public static Object checkFieldValue(FieldType fieldSchema, JsonElement value) {
HashMap errMsgs = getTypeErrorMsgForRowInsert();
DataType dataType = fieldSchema.getDataType();
switch (dataType) {
case FloatVector: {
if (!(value.isJsonArray())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
int dim = fieldSchema.getDimension();
try {
List vector = GSON_INSTANCE.fromJson(value, new TypeToken>() {}.getType());
if (vector.size() != dim) {
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
throw new ParamException(String.format(msg, fieldSchema.getName(), vector.size(), dim));
}
return vector; // return List for genFieldData()
} catch (JsonSyntaxException e) {
throw new ParamException(String.format("Unable to convert JsonArray to List for field '%s'. Reason: %s",
fieldSchema.getName(), e.getCause().getMessage()));
}
}
case BinaryVector: {
if (!(value.isJsonArray())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
int dim = fieldSchema.getDimension();
try {
byte[] v = GSON_INSTANCE.fromJson(value, new TypeToken() {}.getType());
int real_dim = calculateBinVectorDim(dataType, v.length);
if (real_dim != dim) {
String msg = "Incorrect dimension for field '%s': dimension: %d is not equal to field's dimension: %d";
throw new ParamException(String.format(msg, fieldSchema.getName(), real_dim, dim));
}
return ByteBuffer.wrap(v); // return ByteBuffer for genFieldData()
} catch (JsonSyntaxException e) {
throw new ParamException(String.format("Unable to convert JsonArray to List for field '%s'. Reason: %s",
fieldSchema.getName(), e.getCause().getMessage()));
}
}
case Int64:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return value.getAsLong(); // return long for genFieldData()
case Int32:
case Int16:
case Int8:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return value.getAsInt(); // return int for genFieldData()
case Bool:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return value.getAsBoolean(); // return boolean for genFieldData()
case Float:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return value.getAsFloat(); // return float for genFieldData()
case Double:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return value.getAsDouble(); // return double for genFieldData()
case VarChar:
case String:
if (!(value.isJsonPrimitive())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
JsonPrimitive p = value.getAsJsonPrimitive();
if (!p.isString()) {
throw new ParamException(String.format("JsonPrimitive should be String type for field '%s'", fieldSchema.getName()));
}
String str = p.getAsString();
if (str.length() > fieldSchema.getMaxLength()) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return str; // return String for genFieldData()
case JSON:
return value; // return JsonElement for genFieldData()
case Array:
if (!(value.isJsonArray())) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
List array = convertJsonArray(value.getAsJsonArray(), fieldSchema.getElementType(), fieldSchema.getName());
if (array.size() > fieldSchema.getMaxCapacity()) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}
return array; // return List for genFieldData()
default:
throw new ParamException("Unsupported data type returned by FieldData");
}
}
public static List convertJsonArray(JsonArray jsonArray, DataType elementType, String fieldName) {
try {
switch (elementType) {
case Int64:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
case Int32:
case Int16:
case Int8:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
case Bool:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
case Float:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
case Double:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
case VarChar:
return GSON_INSTANCE.fromJson(jsonArray, new TypeToken>() {}.getType());
default:
throw new ParamException(String.format("Unsupported element type of Array field '%s'", fieldName));
}
} catch (JsonSyntaxException e) {
throw new ParamException(String.format("Unable to convert JsonArray to List for field '%s'. Reason: %s",
fieldName, e.getCause().getMessage()));
}
}
/**
 * Verifies that a string is neither null nor blank.
 * Throws {@link ParamException} if the string is null, empty or whitespace-only.
 *
 * @param target target string
 * @param name a name to describe this string, used as the prefix of the error message
 */
public static void CheckNullEmptyString(String target, String name) throws ParamException {
    // StringUtils.isBlank() already treats null as blank
    boolean missing = StringUtils.isBlank(target);
    if (missing) {
        throw new ParamException(name + " cannot be null or empty");
    }
}
/**
 * Verifies that a string is not null; an empty string is accepted.
 * Throws {@link ParamException} if the string is null.
 *
 * @param target target string
 * @param name a name to describe this string, used as the prefix of the error message
 */
public static void CheckNullString(String target, String name) throws ParamException {
    if (target != null) {
        return;
    }
    throw new ParamException(name + " cannot be null");
}
public static class InsertBuilderWrapper {
private InsertRequest.Builder insertBuilder;
private UpsertRequest.Builder upsertBuilder;
/**
 * Creates a wrapper that assembles an insert request from the given parameters.
 *
 * @param requestParam insert parameters (collection, database, partition and data)
 * @param wrapper description of the target collection, used to validate the data
 */
public InsertBuilderWrapper(@NonNull InsertParam requestParam,
                            DescCollResponseWrapper wrapper) {
    // generate insert request builder
    String targetCollection = requestParam.getCollectionName();
    MsgBase base = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
    insertBuilder = InsertRequest.newBuilder()
            .setCollectionName(targetCollection)
            .setBase(base)
            .setNumRows(requestParam.getRowCount());

    // the database name is optional
    String databaseName = requestParam.getDatabaseName();
    if (StringUtils.isNotEmpty(databaseName)) {
        insertBuilder.setDbName(databaseName);
    }

    fillFieldsData(requestParam, wrapper);
}
/**
 * Creates a wrapper that assembles an upsert request from the given parameters.
 *
 * @param requestParam upsert parameters (collection, database, partition and data)
 * @param wrapper description of the target collection, used to validate the data
 * @throws ParamException if the collection's primary key is auto-generated
 */
public InsertBuilderWrapper(@NonNull UpsertParam requestParam,
                            DescCollResponseWrapper wrapper) {
    String targetCollection = requestParam.getCollectionName();

    // currently, not allow to upsert for collection whose primary key is auto-generated
    FieldType primaryField = wrapper.getPrimaryField();
    if (primaryField.isAutoID()) {
        throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s",
                requestParam.getCollectionName()));
    }

    // generate upsert request builder
    // NOTE(review): the message type is MsgType.Insert even for upsert - confirm this is intended.
    MsgBase base = MsgBase.newBuilder().setMsgType(MsgType.Insert).build();
    upsertBuilder = UpsertRequest.newBuilder()
            .setCollectionName(targetCollection)
            .setBase(base)
            .setNumRows(requestParam.getRowCount());

    // the database name is optional
    String databaseName = requestParam.getDatabaseName();
    if (StringUtils.isNotEmpty(databaseName)) {
        upsertBuilder.setDbName(databaseName);
    }

    fillFieldsData(requestParam, wrapper);
}
/**
 * Appends a field's data to whichever request builder is active.
 * The insert builder takes precedence; nothing happens if neither builder exists.
 */
private void addFieldsData(io.milvus.grpc.FieldData value) {
    if (insertBuilder != null) {
        insertBuilder.addFieldsData(value);
        return;
    }
    if (upsertBuilder != null) {
        upsertBuilder.addFieldsData(value);
    }
}
/**
 * Sets the partition name on whichever request builder is active.
 * The insert builder takes precedence; nothing happens if neither builder exists.
 */
private void setPartitionName(String value) {
    if (insertBuilder != null) {
        insertBuilder.setPartitionName(value);
        return;
    }
    if (upsertBuilder != null) {
        upsertBuilder.setPartitionName(value);
    }
}
/**
 * Applies the partition name and converts the insert data into the active request builder.
 *
 * <p>Column-based data is used when the request carries any column; otherwise the row-based
 * data is used.
 *
 * @param requestParam insert/upsert parameters
 * @param wrapper description of the target collection
 * @throws ParamException if a partition name is given for a collection with a partition key,
 *                        or if the data does not match the collection schema
 */
private void fillFieldsData(InsertParam requestParam, DescCollResponseWrapper wrapper) {
    // set partition name only when there is no partition key field
    String partitionName = requestParam.getPartitionName();
    boolean isPartitionKeyEnabled = false;
    for (FieldType fieldType : wrapper.getFields()) {
        if (fieldType.isPartitionKey()) {
            isPartitionKeyEnabled = true;
            break;
        }
    }
    if (isPartitionKeyEnabled) {
        if (partitionName != null && !partitionName.isEmpty()) {
            String msg = String.format("Collection %s has partition key, not allow to specify partition name",
                    requestParam.getCollectionName());
            throw new ParamException(msg);
        }
    } else if (partitionName != null) {
        this.setPartitionName(partitionName);
    }

    // convert insert data
    List<InsertParam.Field> columnFields = requestParam.getFields();
    List<JsonObject> rowFields = requestParam.getRows();
    if (CollectionUtils.isNotEmpty(columnFields)) {
        checkAndSetColumnData(wrapper, columnFields);
    } else {
        checkAndSetRowData(wrapper, rowFields);
    }
}
/**
 * Validates column-based insert data and adds it to the request builder.
 *
 * <p>Fields are emitted in the order of the collection schema. An auto-ID field must not be
 * supplied; every other schema field must be supplied. When the collection enables the dynamic
 * field, a column named {@code Constant.DYNAMIC_FIELD_NAME} is added as a JSON dynamic field.
 *
 * @param wrapper description of the target collection
 * @param fields columns supplied by the caller
 * @throws ParamException if a field is missing, an auto-ID field is supplied, or a value
 *                        does not match the schema
 */
private void checkAndSetColumnData(DescCollResponseWrapper wrapper, List<InsertParam.Field> fields) {
    List<FieldType> fieldTypes = wrapper.getFields();

    // gen fieldData
    // make sure the field order is consistent with the collection schema
    for (FieldType fieldType : fieldTypes) {
        boolean found = false;
        for (InsertParam.Field field : fields) {
            if (field.getName().equals(fieldType.getName())) {
                if (fieldType.isAutoID()) {
                    String msg = String.format("The primary key: %s is auto generated, no need to input.",
                            fieldType.getName());
                    throw new ParamException(msg);
                }
                checkFieldData(fieldType, field);

                found = true;
                this.addFieldsData(genFieldData(fieldType, field.getValues()));
                break;
            }
        }
        if (!found && !fieldType.isAutoID()) {
            throw new ParamException(String.format("The field: %s is not provided.", fieldType.getName()));
        }
    }

    // deal with dynamicField
    if (wrapper.getEnableDynamicField()) {
        for (InsertParam.Field field : fields) {
            if (field.getName().equals(Constant.DYNAMIC_FIELD_NAME)) {
                FieldType dynamicType = FieldType.newBuilder()
                        .withName(Constant.DYNAMIC_FIELD_NAME)
                        .withDataType(DataType.JSON)
                        .withIsDynamic(true)
                        .build();
                checkFieldData(dynamicType, field);
                this.addFieldsData(genFieldData(dynamicType, field.getValues(), true));
                break;
            }
        }
    }
}
/**
 * Validates row-based insert data, pivots it into per-field columns and adds the columns to
 * the request builder.
 *
 * <p>For every row, each schema field must be present and non-null unless it is an auto-ID
 * field, which must be absent. When the collection enables the dynamic field, every row key
 * that has not been collected as a schema field is gathered into one JSON object per row.
 *
 * @param wrapper description of the target collection
 * @param rows rows supplied by the caller
 * @throws ParamException if a field is missing, an auto-ID field is supplied, or a value
 *                        does not match the schema
 */
private void checkAndSetRowData(DescCollResponseWrapper wrapper, List<JsonObject> rows) {
    List<FieldType> fieldTypes = wrapper.getFields();

    // LinkedHashMap keeps the columns in schema order (the order of first insertion), matching
    // the ordering guarantee of the column-based path.
    Map<String, InsertDataInfo> nameInsertInfo = new LinkedHashMap<>();
    InsertDataInfo insertDynamicDataInfo = InsertDataInfo.builder().fieldType(
            FieldType.newBuilder()
                    .withName(Constant.DYNAMIC_FIELD_NAME)
                    .withDataType(DataType.JSON)
                    .withIsDynamic(true)
                    .build())
            .data(new LinkedList<>()).build();

    for (JsonObject row : rows) {
        for (FieldType fieldType : fieldTypes) {
            String fieldName = fieldType.getName();

            // check normalField
            JsonElement rowFieldData = row.get(fieldName);
            if (rowFieldData != null && !rowFieldData.isJsonNull()) {
                if (fieldType.isAutoID()) {
                    String msg = String.format("The primary key: %s is auto generated, no need to input.", fieldName);
                    throw new ParamException(msg);
                }
                Object fieldValue = checkFieldValue(fieldType, rowFieldData);

                // create the column lazily, only once a valid value exists for it
                InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldName);
                if (insertDataInfo == null) {
                    insertDataInfo = InsertDataInfo.builder()
                            .fieldType(fieldType).data(new LinkedList<>()).build();
                    nameInsertInfo.put(fieldName, insertDataInfo);
                }
                insertDataInfo.getData().add(fieldValue);
            } else {
                // check if autoId
                if (!fieldType.isAutoID()) {
                    String msg = String.format("The field: %s is not provided.", fieldType.getName());
                    throw new ParamException(msg);
                }
            }
        }

        // deal with dynamicField
        if (wrapper.getEnableDynamicField()) {
            JsonObject dynamicField = new JsonObject();
            for (String rowFieldName : row.keySet()) {
                if (!nameInsertInfo.containsKey(rowFieldName)) {
                    dynamicField.add(rowFieldName, row.get(rowFieldName));
                }
            }
            insertDynamicDataInfo.getData().add(dynamicField);
        }
    }

    for (InsertDataInfo insertDataInfo : nameInsertInfo.values()) {
        this.addFieldsData(genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData()));
    }
    if (wrapper.getEnableDynamicField()) {
        this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE));
    }
}
/**
 * Builds the insert request.
 *
 * @return the assembled {@link InsertRequest}
 * @throws ParamException if this wrapper holds no insert builder
 */
public InsertRequest buildInsertRequest() {
    if (insertBuilder == null) {
        throw new ParamException("Unable to build insert request since no input");
    }
    return insertBuilder.build();
}
/**
 * Builds the upsert request.
 *
 * @return the assembled {@link UpsertRequest}
 * @throws ParamException if this wrapper holds no upsert builder
 */
public UpsertRequest buildUpsertRequest() {
    if (upsertBuilder == null) {
        throw new ParamException("Unable to build upsert request since no input");
    }
    return upsertBuilder.build();
}
}
@SuppressWarnings("unchecked")
public static ByteString convertPlaceholder(List> vectors, PlaceholderType placeType) throws ParamException {
PlaceholderType plType = PlaceholderType.None;
List byteStrings = new ArrayList<>();
for (Object vector : vectors) {
if (vector instanceof List) {
plType = PlaceholderType.FloatVector;
List list = (List) vector;
ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * list.size());
buf.order(ByteOrder.LITTLE_ENDIAN);
list.forEach(buf::putFloat);
byte[] array = buf.array();
ByteString bs = ByteString.copyFrom(array);
byteStrings.add(bs);
} else if (vector instanceof ByteBuffer) {
plType = PlaceholderType.BinaryVector;
ByteBuffer buf = (ByteBuffer) vector;
byte[] array = buf.array();
ByteString bs = ByteString.copyFrom(array);
byteStrings.add(bs);
} else {
String msg = "Search target vector type is illegal." +
" Only allow List for FloatVector, ByteBuffer for BinaryVector";
throw new ParamException(msg);
}
}
// force specify PlaceholderType
if (placeType != PlaceholderType.None) {
plType = placeType;
}
PlaceholderValue.Builder pldBuilder = PlaceholderValue.newBuilder()
.setTag(Constant.VECTOR_TAG)
.setType(plType);
byteStrings.forEach(pldBuilder::addValues);
PlaceholderValue plv = pldBuilder.build();
PlaceholderGroup placeholderGroup = PlaceholderGroup.newBuilder()
.addPlaceholders(plv)
.build();
return placeholderGroup.toByteString();
}
/**
 * Converts a {@link SearchParam} into a grpc {@link SearchRequest}.
 *
 * @param requestParam search parameters
 * @return the assembled search request
 * @throws ParamException if the target vectors are illegal or the extra search params
 *                        cannot be parsed
 */
@SuppressWarnings("unchecked")
public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam) throws ParamException {
    SearchRequest.Builder builder = SearchRequest.newBuilder()
            .setCollectionName(requestParam.getCollectionName());

    if (!requestParam.getPartitionNames().isEmpty()) {
        requestParam.getPartitionNames().forEach(builder::addPartitionNames);
    }

    if (StringUtils.isNotEmpty(requestParam.getDatabaseName())) {
        builder.setDbName(requestParam.getDatabaseName());
    }

    // prepare target vectors
    ByteString byteStr = convertPlaceholder(requestParam.getVectors(), PlaceholderType.None);
    builder.setPlaceholderGroup(byteStr);
    builder.setNq(requestParam.getNQ());

    // search parameters
    builder.addSearchParams(
            KeyValuePair.newBuilder()
                    .setKey(Constant.VECTOR_FIELD)
                    .setValue(requestParam.getVectorFieldName())
                    .build())
            .addSearchParams(
                    KeyValuePair.newBuilder()
                            .setKey(Constant.TOP_K)
                            .setValue(String.valueOf(requestParam.getTopK()))
                            .build())
            .addSearchParams(
                    KeyValuePair.newBuilder()
                            .setKey(Constant.ROUND_DECIMAL)
                            .setValue(String.valueOf(requestParam.getRoundDecimal()))
                            .build())
            .addSearchParams(
                    KeyValuePair.newBuilder()
                            .setKey(Constant.IGNORE_GROWING)
                            .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
                            .build());

    if (!Objects.equals(requestParam.getMetricType(), MetricType.None.name())) {
        builder.addSearchParams(
                KeyValuePair.newBuilder()
                        .setKey(Constant.METRIC_TYPE)
                        .setValue(requestParam.getMetricType())
                        .build());
    }

    if (null != requestParam.getParams() && !requestParam.getParams().isEmpty()) {
        try {
            Map<String, Object> paramMap = JacksonUtils.fromJson(requestParam.getParams(), Map.class);
            // An explicit JSON null for the offset is treated like a missing offset
            // instead of failing on a null dereference.
            Object offsetValue = paramMap.get(Constant.OFFSET);
            String offset = (offsetValue == null) ? "0" : offsetValue.toString();
            builder.addSearchParams(
                    KeyValuePair.newBuilder()
                            .setKey(Constant.OFFSET)
                            .setValue(offset)
                            .build());
            builder.addSearchParams(
                    KeyValuePair.newBuilder()
                            .setKey(Constant.PARAMS)
                            .setValue(requestParam.getParams())
                            .build());
        } catch (IllegalArgumentException e) {
            // The exception may have no cause; append the cause's message only when present.
            Throwable cause = e.getCause();
            String causeMessage = (cause == null) ? "" : cause.getMessage();
            throw new ParamException(e.getMessage() + causeMessage);
        }
    }

    if (!requestParam.getOutFields().isEmpty()) {
        requestParam.getOutFields().forEach(builder::addOutputFields);
    }

    // always use expression since dsl is discarded
    builder.setDslType(DslType.BoolExprV1);
    if (requestParam.getExpr() != null && !requestParam.getExpr().isEmpty()) {
        builder.setDsl(requestParam.getExpr());
    }

    long guaranteeTimestamp = getGuaranteeTimestamp(requestParam.getConsistencyLevel(),
            requestParam.getGuaranteeTimestamp(), requestParam.getGracefulTime());
    builder.setTravelTimestamp(requestParam.getTravelTimestamp());
    builder.setGuaranteeTimestamp(guaranteeTimestamp);

    // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true
    if (requestParam.getConsistencyLevel() == null) {
        builder.setUseDefaultConsistency(true);
    } else {
        builder.setConsistencyLevelValue(requestParam.getConsistencyLevel().getCode());
    }

    return builder.build();
}
/**
 * Converts a {@link QueryParam} into a grpc {@link QueryRequest}.
 *
 * @param requestParam query parameters
 * @return the assembled query request
 */
public static QueryRequest convertQueryParam(@NonNull QueryParam requestParam) {
    long guaranteeTs = getGuaranteeTimestamp(requestParam.getConsistencyLevel(),
            requestParam.getGuaranteeTimestamp(), requestParam.getGracefulTime());

    QueryRequest.Builder request = QueryRequest.newBuilder()
            .setCollectionName(requestParam.getCollectionName())
            .addAllPartitionNames(requestParam.getPartitionNames())
            .addAllOutputFields(requestParam.getOutFields())
            .setExpr(requestParam.getExpr())
            .setTravelTimestamp(requestParam.getTravelTimestamp())
            .setGuaranteeTimestamp(guaranteeTs);

    // the database name is optional
    if (StringUtils.isNotEmpty(requestParam.getDatabaseName())) {
        request.setDbName(requestParam.getDatabaseName());
    }

    // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true
    if (requestParam.getConsistencyLevel() != null) {
        request.setConsistencyLevelValue(requestParam.getConsistencyLevel().getCode());
    } else {
        request.setUseDefaultConsistency(true);
    }

    // set offset and limit value.
    // directly pass the two values, the server will verify them.
    long offset = requestParam.getOffset();
    if (offset > 0) {
        KeyValuePair offsetPair = KeyValuePair.newBuilder()
                .setKey(Constant.OFFSET)
                .setValue(String.valueOf(offset))
                .build();
        request.addQueryParams(offsetPair);
    }

    long limit = requestParam.getLimit();
    if (limit > 0) {
        KeyValuePair limitPair = KeyValuePair.newBuilder()
                .setKey(Constant.LIMIT)
                .setValue(String.valueOf(limit))
                .build();
        request.addQueryParams(limitPair);
    }

    // ignore growing
    KeyValuePair ignoreGrowingPair = KeyValuePair.newBuilder()
            .setKey(Constant.IGNORE_GROWING)
            .setValue(String.valueOf(requestParam.isIgnoreGrowing()))
            .build();
    request.addQueryParams(ignoreGrowingPair);

    return request.build();
}
/**
 * Derives the guarantee timestamp from the consistency level.
 *
 * <ul>
 *   <li>no level: 1</li>
 *   <li>STRONG: 0</li>
 *   <li>BOUNDED: current time in milliseconds minus {@code gracefulTime}</li>
 *   <li>EVENTUALLY: 1</li>
 *   <li>any other level: the {@code guaranteeTimestamp} passed in</li>
 * </ul>
 *
 * @param consistencyLevel consistency level, may be null
 * @param guaranteeTimestamp timestamp used when the level does not define its own
 * @param gracefulTime milliseconds subtracted from the current time for BOUNDED
 * @return the guarantee timestamp to send to the server
 */
private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel,
                                          long guaranteeTimestamp, Long gracefulTime) {
    if (consistencyLevel == null) {
        return 1L;
    }
    if (consistencyLevel == ConsistencyLevelEnum.STRONG) {
        return 0L;
    }
    if (consistencyLevel == ConsistencyLevelEnum.BOUNDED) {
        return System.currentTimeMillis() - gracefulTime;
    }
    if (consistencyLevel == ConsistencyLevelEnum.EVENTUALLY) {
        return 1L;
    }
    return guaranteeTimestamp;
}
/**
 * Tells whether a data type is a vector type.
 *
 * @param dataType the data type to test, may be null
 * @return {@code true} for FloatVector and BinaryVector, {@code false} otherwise
 */
public static boolean isVectorDataType(DataType dataType) {
    // Direct comparison avoids building an anonymous HashSet subclass on every call.
    return dataType == DataType.FloatVector || dataType == DataType.BinaryVector;
}
public static FieldData genFieldData(FieldType fieldType, List> objects) {
return genFieldData(fieldType, objects, Boolean.FALSE);
}
public static FieldData genFieldData(FieldType fieldType, List> objects, boolean isDynamic) {
if (objects == null) {
throw new ParamException("Cannot generate FieldData from null object");
}
DataType dataType = fieldType.getDataType();
String fieldName = fieldType.getName();
FieldData.Builder builder = FieldData.newBuilder();
if (isVectorDataType(dataType)) {
VectorField vectorField = genVectorField(dataType, objects);
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
} else {
ScalarField scalarField = genScalarField(fieldType, objects);
if (isDynamic) {
return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build();
}
return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build();
}
}
@SuppressWarnings("unchecked")
private static VectorField genVectorField(DataType dataType, List> objects) {
if (dataType == DataType.FloatVector) {
List floats = new ArrayList<>();
// each object is List
for (Object object : objects) {
if (object instanceof List) {
List list = (List) object;
floats.addAll(list);
} else {
throw new ParamException("The type of FloatVector must be List");
}
}
int dim = floats.size() / objects.size();
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build();
} else if (dataType == DataType.BinaryVector) {
ByteBuffer totalBuf = null;
int dim = 0;
// each object is ByteBuffer
for (Object object : objects) {
ByteBuffer buf = (ByteBuffer) object;
if (totalBuf == null) {
totalBuf = ByteBuffer.allocate(buf.limit() * objects.size());
totalBuf.put(buf.array());
dim = calculateBinVectorDim(dataType, buf.limit());
} else {
totalBuf.put(buf.array());
}
}
assert totalBuf != null;
ByteString byteString = ByteString.copyFrom(totalBuf.array());
return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build();
}
throw new ParamException("Illegal vector dataType:" + dataType);
}
private static ScalarField genScalarField(FieldType fieldType, List> objects) {
if (fieldType.getDataType() == DataType.Array) {
ArrayArray.Builder builder = ArrayArray.newBuilder();
for (Object object : objects) {
List> temp = (List>)object;
ScalarField arrayField = genScalarField(fieldType.getElementType(), temp);
builder.addData(arrayField);
}
return ScalarField.newBuilder().setArrayData(builder.build()).build();
} else {
return genScalarField(fieldType.getDataType(), objects);
}
}
private static ScalarField genScalarField(DataType dataType, List> objects) {
switch (dataType) {
case None:
case UNRECOGNIZED:
throw new ParamException("Cannot support this dataType:" + dataType);
case Int64: {
List longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList());
LongArray longArray = LongArray.newBuilder().addAllData(longs).build();
return ScalarField.newBuilder().setLongData(longArray).build();
}
case Int32:
case Int16:
case Int8: {
List integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList());
IntArray intArray = IntArray.newBuilder().addAllData(integers).build();
return ScalarField.newBuilder().setIntData(intArray).build();
}
case Bool: {
List booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList());
BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build();
return ScalarField.newBuilder().setBoolData(boolArray).build();
}
case Float: {
List floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList());
FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build();
return ScalarField.newBuilder().setFloatData(floatArray).build();
}
case Double: {
List doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList());
DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build();
return ScalarField.newBuilder().setDoubleData(doubleArray).build();
}
case String:
case VarChar: {
List strings = objects.stream().map(p -> (String) p).collect(Collectors.toList());
StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
return ScalarField.newBuilder().setStringData(stringArray).build();
}
case JSON: {
List byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(p.toString()))
.collect(Collectors.toList());
JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build();
return ScalarField.newBuilder().setJsonData(jsonArray).build();
}
default:
throw new ParamException("Illegal scalar dataType:" + dataType);
}
}
/**
 * Convert a grpc field schema to client field schema
 *
 * @param field FieldSchema object
 * @return {@link FieldType} schema of the field
 */
public static FieldType ConvertField(@NonNull FieldSchema field) {
    // the dynamic flag is copied once here; no separate conditional set is needed
    FieldType.Builder builder = FieldType.newBuilder()
            .withName(field.getName())
            .withDescription(field.getDescription())
            .withPrimaryKey(field.getIsPrimaryKey())
            .withPartitionKey(field.getIsPartitionKey())
            .withAutoID(field.getAutoID())
            .withDataType(field.getDataType())
            .withElementType(field.getElementType())
            .withIsDynamic(field.getIsDynamic());

    // copy every type parameter of the grpc schema
    List<KeyValuePair> keyValuePairs = field.getTypeParamsList();
    keyValuePairs.forEach((kv) -> builder.addTypeParam(kv.getKey(), kv.getValue()));

    return builder.build();
}
/**
 * Convert a client field schema to grpc field schema
 *
 * @param field {@link FieldType} object
 * @return {@link FieldSchema} schema of the field
 */
public static FieldSchema ConvertField(@NonNull FieldType field) {
    FieldSchema.Builder builder = FieldSchema.newBuilder()
            .setName(field.getName())
            .setDescription(field.getDescription())
            .setIsPrimaryKey(field.isPrimaryKey())
            .setIsPartitionKey(field.isPartitionKey())
            .setAutoID(field.isAutoID())
            .setDataType(field.getDataType())
            .setElementType(field.getElementType())
            .setIsDynamic(field.isDynamic());

    // assemble typeParams for CollectionSchema
    List<KeyValuePair> typeParamsList = AssembleKvPair(field.getTypeParams());
    if (CollectionUtils.isNotEmpty(typeParamsList)) {
        typeParamsList.forEach(builder::addTypeParams);
    }

    return builder.build();
}
/**
 * Converts a string map into a list of grpc key-value pairs.
 *
 * @param sourceMap the map to convert, may be null or empty
 * @return one {@link KeyValuePair} per map entry; an empty list if the map is null or empty
 */
public static List<KeyValuePair> AssembleKvPair(Map<String, String> sourceMap) {
    List<KeyValuePair> result = new ArrayList<>();

    if (MapUtils.isNotEmpty(sourceMap)) {
        sourceMap.forEach((key, value) -> {
            KeyValuePair kv = KeyValuePair.newBuilder()
                    .setKey(key)
                    .setValue(value).build();
            result.add(kv);
        });
    }
    return result;
}
/**
 * Holds one field's schema together with the values collected for it while pivoting
 * row-based insert data into columns.
 */
@Builder
@Getter
public static class InsertDataInfo {
    // schema of the field the values belong to
    private final FieldType fieldType;
    // collected values, one per row
    private final LinkedList<Object> data;
}
}