All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.plugin.ml.ModelUtils Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.plugin.ml;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlMap;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public final class ModelUtils
{
    private static final int VERSION_OFFSET = 0;
    private static final int HASH_OFFSET = VERSION_OFFSET + SIZE_OF_INT;
    private static final int ALGORITHM_OFFSET = HASH_OFFSET + 32;
    private static final int HYPERPARAMETER_LENGTH_OFFSET = ALGORITHM_OFFSET + SIZE_OF_INT;
    private static final int HYPERPARAMETERS_OFFSET = HYPERPARAMETER_LENGTH_OFFSET + SIZE_OF_INT;

    private static final int CURRENT_FORMAT_VERSION = 1;

    // These ids are serialized to disk. Do not change them.
    @VisibleForTesting
    static final BiMap, Integer> MODEL_SERIALIZATION_IDS;

    static {
        ImmutableBiMap.Builder, Integer> builder = ImmutableBiMap.builder();
        builder.put(SvmClassifier.class, 1);
        builder.put(SvmRegressor.class, 2);
        builder.put(FeatureVectorUnitNormalizer.class, 3);
        builder.put(ClassifierFeatureTransformer.class, 4);
        builder.put(RegressorFeatureTransformer.class, 5);
        builder.put(FeatureUnitNormalizer.class, 6);
        builder.put(StringClassifierAdapter.class, 7);

        MODEL_SERIALIZATION_IDS = builder.build();
    }

    private ModelUtils() {}

    /**
     * Serializes the model using the following format
     * int: format version
     * byte[32]: SHA256 hash of all following data
     * int: id of algorithm
     * int: length of hyperparameters section
     * byte[]: hyperparameters (currently not used)
     * long: length of data section
     * byte[]: model data
     * 

* note: all multibyte values are in little endian */ public static Slice serialize(Model model) { requireNonNull(model, "model is null"); Integer id = MODEL_SERIALIZATION_IDS.get(model.getClass()); requireNonNull(id, "id is null"); int size = HYPERPARAMETERS_OFFSET; // hyperparameters aren't implemented yet byte[] hyperparameters = new byte[0]; size += hyperparameters.length; int dataLengthOffset = size; size += SIZE_OF_LONG; int dataOffset = size; byte[] data = model.getSerializedData(); size += data.length; Slice slice = Slices.allocate(size); slice.setInt(VERSION_OFFSET, CURRENT_FORMAT_VERSION); slice.setInt(ALGORITHM_OFFSET, id); slice.setInt(HYPERPARAMETER_LENGTH_OFFSET, hyperparameters.length); slice.setBytes(HYPERPARAMETERS_OFFSET, hyperparameters); slice.setLong(dataLengthOffset, data.length); slice.setBytes(dataOffset, data); byte[] modelHash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET)).asBytes(); checkState(modelHash.length == 32, "sha256 hash code expected to be 32 bytes"); slice.setBytes(HASH_OFFSET, modelHash); return slice; } public static HashCode modelHash(Slice slice) { return HashCode.fromBytes(slice.getBytes(HASH_OFFSET, 32)); } public static Model deserialize(byte[] data) { return deserialize(Slices.wrappedBuffer(data)); } public static Model deserialize(Slice slice) { int version = slice.getInt(VERSION_OFFSET); checkArgument(version == CURRENT_FORMAT_VERSION, "Unsupported version: %s", version); byte[] modelHashBytes = slice.getBytes(HASH_OFFSET, 32); HashCode expectedHash = HashCode.fromBytes(modelHashBytes); HashCode actualHash = Hashing.sha256().hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET)); checkArgument(actualHash.equals(expectedHash), "model hash does not match data"); int id = slice.getInt(ALGORITHM_OFFSET); Class algorithm = MODEL_SERIALIZATION_IDS.inverse().get(id); requireNonNull(algorithm, format("Unsupported algorith %d", id)); int hyperparameterLength = slice.getInt(HYPERPARAMETER_LENGTH_OFFSET); int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameterLength; long dataLength = slice.getLong(dataLengthOffset); int dataOffset = dataLengthOffset + SIZE_OF_LONG; byte[] data = slice.getBytes(dataOffset, (int) dataLength); try { Method deserialize = algorithm.getMethod("deserialize", byte[].class); return (Model) deserialize.invoke(null, new Object[] {data}); } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { throw new RuntimeException(e); } } public static byte[] serializeModels(Model... models) { List serializedModels = new ArrayList<>(); int size = SIZE_OF_INT + SIZE_OF_INT * models.length; for (Model model : models) { byte[] bytes = serialize(model).getBytes(); size += bytes.length; serializedModels.add(bytes); } Slice slice = Slices.allocate(size); slice.setInt(0, models.length); for (int i = 0; i < models.length; i++) { slice.setInt(SIZE_OF_INT * (i + 1), serializedModels.get(i).length); } int offset = SIZE_OF_INT + SIZE_OF_INT * models.length; for (byte[] bytes : serializedModels) { slice.setBytes(offset, bytes); offset += bytes.length; } return slice.getBytes(); } public static List deserializeModels(byte[] bytes) { Slice slice = Slices.wrappedBuffer(bytes); int numModels = slice.getInt(0); int offset = SIZE_OF_INT + SIZE_OF_INT * numModels; ImmutableList.Builder models = ImmutableList.builder(); for (int i = 0; i < numModels; i++) { int length = slice.getInt(SIZE_OF_INT * (i + 1)); models.add(deserialize(slice.getBytes(offset, length))); offset += length; } return models.build(); } //TODO: instead of having this function, we should add feature extractors that extend Model and extract features from Strings public static FeatureVector toFeatures(SqlMap sqlMap) { Map features = new HashMap<>(); if (sqlMap != null) { int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); for (int i = 0; i < sqlMap.getSize(); i++) { features.put((int) BIGINT.getLong(rawKeyBlock, rawOffset + i), DOUBLE.getDouble(rawValueBlock, rawOffset + i)); } } return new FeatureVector(features); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy