// io.trino.plugin.ml.ModelUtils (trino-ml artifact)
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.ml;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlMap;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
public final class ModelUtils
{
private static final int VERSION_OFFSET = 0;
private static final int HASH_OFFSET = VERSION_OFFSET + SIZE_OF_INT;
private static final int ALGORITHM_OFFSET = HASH_OFFSET + 32;
private static final int HYPERPARAMETER_LENGTH_OFFSET = ALGORITHM_OFFSET + SIZE_OF_INT;
private static final int HYPERPARAMETERS_OFFSET = HYPERPARAMETER_LENGTH_OFFSET + SIZE_OF_INT;
private static final int CURRENT_FORMAT_VERSION = 1;
// These ids are serialized to disk. Do not change them.
@VisibleForTesting
static final BiMap, Integer> MODEL_SERIALIZATION_IDS;
static {
ImmutableBiMap.Builder, Integer> builder = ImmutableBiMap.builder();
builder.put(SvmClassifier.class, 1);
builder.put(SvmRegressor.class, 2);
builder.put(FeatureVectorUnitNormalizer.class, 3);
builder.put(ClassifierFeatureTransformer.class, 4);
builder.put(RegressorFeatureTransformer.class, 5);
builder.put(FeatureUnitNormalizer.class, 6);
builder.put(StringClassifierAdapter.class, 7);
MODEL_SERIALIZATION_IDS = builder.build();
}
private ModelUtils() {}
/**
* Serializes the model using the following format
* int: format version
* byte[32]: SHA256 hash of all following data
* int: id of algorithm
* int: length of hyperparameters section
* byte[]: hyperparameters (currently not used)
* long: length of data section
* byte[]: model data
*
* note: all multibyte values are in little endian
*/
public static Slice serialize(Model model)
{
requireNonNull(model, "model is null");
Integer id = MODEL_SERIALIZATION_IDS.get(model.getClass());
requireNonNull(id, "id is null");
int size = HYPERPARAMETERS_OFFSET;
// hyperparameters aren't implemented yet
byte[] hyperparameters = new byte[0];
size += hyperparameters.length;
int dataLengthOffset = size;
size += SIZE_OF_LONG;
int dataOffset = size;
byte[] data = model.getSerializedData();
size += data.length;
Slice slice = Slices.allocate(size);
slice.setInt(VERSION_OFFSET, CURRENT_FORMAT_VERSION);
slice.setInt(ALGORITHM_OFFSET, id);
slice.setInt(HYPERPARAMETER_LENGTH_OFFSET, hyperparameters.length);
slice.setBytes(HYPERPARAMETERS_OFFSET, hyperparameters);
slice.setLong(dataLengthOffset, data.length);
slice.setBytes(dataOffset, data);
byte[] modelHash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET)).asBytes();
checkState(modelHash.length == 32, "sha256 hash code expected to be 32 bytes");
slice.setBytes(HASH_OFFSET, modelHash);
return slice;
}
public static HashCode modelHash(Slice slice)
{
return HashCode.fromBytes(slice.getBytes(HASH_OFFSET, 32));
}
public static Model deserialize(byte[] data)
{
return deserialize(Slices.wrappedBuffer(data));
}
public static Model deserialize(Slice slice)
{
int version = slice.getInt(VERSION_OFFSET);
checkArgument(version == CURRENT_FORMAT_VERSION, "Unsupported version: %s", version);
byte[] modelHashBytes = slice.getBytes(HASH_OFFSET, 32);
HashCode expectedHash = HashCode.fromBytes(modelHashBytes);
HashCode actualHash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET));
checkArgument(actualHash.equals(expectedHash), "model hash does not match data");
int id = slice.getInt(ALGORITHM_OFFSET);
Class extends Model> algorithm = MODEL_SERIALIZATION_IDS.inverse().get(id);
requireNonNull(algorithm, format("Unsupported algorith %d", id));
int hyperparameterLength = slice.getInt(HYPERPARAMETER_LENGTH_OFFSET);
int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameterLength;
long dataLength = slice.getLong(dataLengthOffset);
int dataOffset = dataLengthOffset + SIZE_OF_LONG;
byte[] data = slice.getBytes(dataOffset, (int) dataLength);
try {
Method deserialize = algorithm.getMethod("deserialize", byte[].class);
return (Model) deserialize.invoke(null, new Object[] {data});
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
public static byte[] serializeModels(Model... models)
{
List serializedModels = new ArrayList<>();
int size = SIZE_OF_INT + SIZE_OF_INT * models.length;
for (Model model : models) {
byte[] bytes = serialize(model).getBytes();
size += bytes.length;
serializedModels.add(bytes);
}
Slice slice = Slices.allocate(size);
slice.setInt(0, models.length);
for (int i = 0; i < models.length; i++) {
slice.setInt(SIZE_OF_INT * (i + 1), serializedModels.get(i).length);
}
int offset = SIZE_OF_INT + SIZE_OF_INT * models.length;
for (byte[] bytes : serializedModels) {
slice.setBytes(offset, bytes);
offset += bytes.length;
}
return slice.getBytes();
}
public static List deserializeModels(byte[] bytes)
{
Slice slice = Slices.wrappedBuffer(bytes);
int numModels = slice.getInt(0);
int offset = SIZE_OF_INT + SIZE_OF_INT * numModels;
ImmutableList.Builder models = ImmutableList.builder();
for (int i = 0; i < numModels; i++) {
int length = slice.getInt(SIZE_OF_INT * (i + 1));
models.add(deserialize(slice.getBytes(offset, length)));
offset += length;
}
return models.build();
}
//TODO: instead of having this function, we should add feature extractors that extend Model and extract features from Strings
public static FeatureVector toFeatures(SqlMap sqlMap)
{
Map features = new HashMap<>();
if (sqlMap != null) {
int rawOffset = sqlMap.getRawOffset();
Block rawKeyBlock = sqlMap.getRawKeyBlock();
Block rawValueBlock = sqlMap.getRawValueBlock();
for (int i = 0; i < sqlMap.getSize(); i++) {
features.put((int) BIGINT.getLong(rawKeyBlock, rawOffset + i), DOUBLE.getDouble(rawValueBlock, rawOffset + i));
}
}
return new FeatureVector(features);
}
}