All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.cassandra.db.marshal.VectorType Maven / Gradle / Ivy

Go to download

The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.

There is a newer version: 5.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.db.marshal;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

import javax.annotation.Nullable;

import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.cql3.Term;
import org.apache.cassandra.cql3.Vectors;
import org.apache.cassandra.db.TypeSizes;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.serializers.MarshalException;
import org.apache.cassandra.serializers.TypeSerializer;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.JsonUtils;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.cassandra.utils.bytecomparable.ByteSource;

public final class VectorType extends AbstractType>
{
    private static class Key
    {
        private final AbstractType type;
        private final int dimension;

        private Key(AbstractType type, int dimension)
        {
            this.type = type;
            this.dimension = dimension;
        }

        private VectorType create()
        {
            return new VectorType<>(type, dimension);
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Key key = (Key) o;
            return dimension == key.dimension && Objects.equals(type, key.type);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(type, dimension);
        }
    }
    @SuppressWarnings("rawtypes")
    private static final ConcurrentHashMap instances = new ConcurrentHashMap<>();

    public final AbstractType elementType;
    public final int dimension;
    private final TypeSerializer elementSerializer;
    private final int valueLengthIfFixed;
    private final VectorSerializer serializer;

    private VectorType(AbstractType elementType, int dimension)
    {
        super(ComparisonType.CUSTOM);
        if (dimension <= 0)
            throw new InvalidRequestException(String.format("vectors may only have positive dimensions; given %d", dimension));
        this.elementType = elementType;
        this.dimension = dimension;
        this.elementSerializer = elementType.getSerializer();
        this.valueLengthIfFixed = elementType.isValueLengthFixed() ?
                                  elementType.valueLengthIfFixed() * dimension :
                                  super.valueLengthIfFixed();
        this.serializer = elementType.isValueLengthFixed() ?
                          new FixedLengthSerializer() :
                          new VariableLengthSerializer();
    }

    @SuppressWarnings("unchecked")
    public static  VectorType getInstance(AbstractType elements, int dimension)
    {
        Key key = new Key(elements, dimension);
        return instances.computeIfAbsent(key, Key::create);
    }

    public static VectorType getInstance(TypeParser parser)
    {
        TypeParser.Vector v = parser.getVectorParameters();
        return getInstance(v.type.freeze(), v.dimension);
    }

    @Override
    public boolean isVector()
    {
        return true;
    }

    @Override
    public  int compareCustom(VL left, ValueAccessor accessorL, VR right, ValueAccessor accessorR)
    {
        return getSerializer().compareCustom(left, accessorL, right, accessorR);
    }

    @Override
    public int valueLengthIfFixed()
    {
        return valueLengthIfFixed;
    }

    @Override
    public VectorSerializer getSerializer()
    {
        return serializer;
    }

    public List split(ByteBuffer buffer)
    {
        return split(buffer, ByteBufferAccessor.instance);
    }

    public  List split(V buffer, ValueAccessor accessor)
    {
        return getSerializer().split(buffer, accessor);
    }

    public float[] composeAsFloat(ByteBuffer input)
    {
        return composeAsFloat(input, ByteBufferAccessor.instance);
    }

    public  float[] composeAsFloat(V input, ValueAccessor accessor)
    {
        if (!(elementType instanceof FloatType))
            throw new IllegalStateException("Attempted to read as float, but element type is " + elementType.asCQL3Type());

        if (isNull(input, accessor))
            return null;

        return accessor.toFloatArray(input, dimension);
    }

    public ByteBuffer decompose(T... values)
    {
        return decompose(Arrays.asList(values));
    }

    public ByteBuffer decomposeAsFloat(float[] value)
    {
        return decomposeAsFloat(ByteBufferAccessor.instance, value);
    }

    public  V decomposeAsFloat(ValueAccessor accessor, float[] value)
    {
        if (value == null)
            rejectNullOrEmptyValue();
        if (!(elementType instanceof FloatType))
            throw new IllegalStateException("Attempted to read as float, but element type is " + elementType.asCQL3Type());
        if (value.length != dimension)
            throw new IllegalArgumentException(String.format("Attempted to add float vector of dimension %d to %s", value.length, asCQL3Type()));
        // TODO : should we use TypeSizes to be consistent with other code?  Its the same value at the end of the day...
        V buffer = accessor.allocate(Float.BYTES * dimension);
        int offset = 0;
        for (int i = 0; i < dimension; i++)
        {
            accessor.putFloat(buffer, offset, value[i]);
            offset+= Float.BYTES;
        }
        return buffer;
    }

    public ByteBuffer decomposeRaw(List elements)
    {
        return decomposeRaw(elements, ByteBufferAccessor.instance);
    }

    public  V decomposeRaw(List elements, ValueAccessor accessor)
    {
        return getSerializer().serializeRaw(elements, accessor);
    }

    @Override
    public  ByteSource asComparableBytes(ValueAccessor accessor, V value, ByteComparable.Version version)
    {
        if (isNull(value, accessor))
            return null;
        ByteSource[] srcs = new ByteSource[dimension];
        List split = split(value, accessor);
        for (int i = 0; i < dimension; i++)
            srcs[i] = elementType.asComparableBytes(accessor, split.get(i), version);
        return ByteSource.withTerminatorMaybeLegacy(version, 0x00, srcs);
    }

    @Override
    public  V fromComparableBytes(ValueAccessor accessor, ByteSource.Peekable comparableBytes, ByteComparable.Version version)
    {
        if (comparableBytes == null)
            rejectNullOrEmptyValue();

        assert version != ByteComparable.Version.LEGACY; // legacy translation is not reversible

        List buffers = new ArrayList<>();
        int separator = comparableBytes.next();
        while (separator != ByteSource.TERMINATOR)
        {
            buffers.add(elementType.fromComparableBytes(accessor, comparableBytes, version));
            separator = comparableBytes.next();
        }
        return decomposeRaw(buffers, accessor);
    }

    @Override
    public CQL3Type asCQL3Type()
    {
        return new CQL3Type.Vector(this);
    }

    public AbstractType getElementsType()
    {
        return elementType;
    }

    // vector of nested types is hard to parse, so fall back to bytes string matching ListType
    @Override
    public  String getString(V value, ValueAccessor accessor)
    {
        return BytesType.instance.getString(value, accessor);
    }

    @Override
    public ByteBuffer fromString(String source) throws MarshalException
    {
        try
        {
            return ByteBufferUtil.hexToBytes(source);
        }
        catch (NumberFormatException e)
        {
            throw new MarshalException(String.format("cannot parse '%s' as hex bytes", source), e);
        }
    }

    @Override
    public List> subTypes()
    {
        return Collections.singletonList(elementType);
    }

    @Override
    public String toJSONString(ByteBuffer buffer, ProtocolVersion protocolVersion)
    {
        return toJSONString(buffer, ByteBufferAccessor.instance, protocolVersion);
    }

    @Override
    public  String toJSONString(V value, ValueAccessor accessor, ProtocolVersion protocolVersion)
    {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        List split = split(value, accessor);
        for (int i = 0; i < dimension; i++)
        {
            if (i > 0)
                sb.append(", ");
            sb.append(elementType.toJSONString(split.get(i), accessor, protocolVersion));
        }
        sb.append(']');
        return sb.toString();
    }

    @Override
    public Term fromJSONObject(Object parsed) throws MarshalException
    {
        if (parsed instanceof String)
            parsed = JsonUtils.decodeJson((String) parsed);

        if (!(parsed instanceof List))
            throw new MarshalException(String.format(
            "Expected a list, but got a %s: %s", parsed.getClass().getSimpleName(), parsed));

        List list = (List) parsed;
        if (list.size() != dimension)
            throw new MarshalException(String.format("List had incorrect size: expected %d but given %d; %s", dimension, list.size(), list));
        List terms = new ArrayList<>(list.size());
        for (Object element : list)
        {
            if (element == null)
                throw new MarshalException("Invalid null element in list");
            terms.add(elementType.fromJSONObject(element));
        }

        return new Vectors.DelayedValue<>(this, terms);
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        VectorType that = (VectorType) o;
        return dimension == that.dimension && Objects.equals(elementType, that.elementType);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(elementType, dimension);
    }

    @Override
    public String toString()
    {
        return toString(false);
    }

    @Override
    public String toString(boolean ignoreFreezing)
    {
        return getClass().getName() + TypeParser.stringifyVectorParameters(elementType, ignoreFreezing, dimension);
    }

    private void check(List values)
    {
        if (values.size() != dimension)
            throw new MarshalException(String.format("Required %d elements, but saw %d", dimension, values.size()));

        // This code base always works with a list that is RandomAccess, so can use .get to avoid allocation
        for (int i = 0; i < dimension; i++)
        {
            Object value = values.get(i);
            if (value == null || (value instanceof ByteBuffer && elementSerializer.isNull((ByteBuffer) value)))
                throw new MarshalException(String.format("Element at index %d is null (expected type %s); given %s", i, elementType.asCQL3Type(), values));
        }
    }

    private  void checkConsumedFully(V buffer, ValueAccessor accessor, int offset)
    {
        int remaining = accessor.sizeFromOffset(buffer, offset);
        if (remaining > 0)
            throw new MarshalException("Unexpected " + remaining + " extraneous bytes after " + asCQL3Type() + " value");
    }
    
    private static void rejectNullOrEmptyValue()
    {
        throw new MarshalException("Invalid empty vector value");
    }

    @Override
    public ByteBuffer getMaskedValue()
    {
        List values = Collections.nCopies(dimension, elementType.getMaskedValue());
        return serializer.serializeRaw(values, ByteBufferAccessor.instance);
    }

    public abstract class VectorSerializer extends TypeSerializer>
    {
        public abstract  int compareCustom(VL left, ValueAccessor accessorL, VR right, ValueAccessor accessorR);

        public abstract  List split(V buffer, ValueAccessor accessor);
        public abstract  V serializeRaw(List elements, ValueAccessor accessor);

        @Override
        public String toString(List value)
        {
            StringBuilder sb = new StringBuilder();
            boolean isFirst = true;
            sb.append('[');
            for (T element : value)
            {
                if (isFirst)
                    isFirst = false;
                else
                    sb.append(", ");
                sb.append(elementSerializer.toString(element));
            }
            sb.append(']');
            return sb.toString();
        }

        @Override
        @SuppressWarnings({ "rawtypes", "unchecked" })
        public Class> getType()
        {
            return (Class) List.class;
        }

        @Override
        public  boolean isNull(@Nullable V buffer, ValueAccessor accessor)
        {
            // we don't allow empty vectors, so we can just check for null
            return buffer == null;
        }
    }

    private class FixedLengthSerializer extends VectorSerializer
    {
        private FixedLengthSerializer()
        {
        }

        @Override
        public  int compareCustom(VL left, ValueAccessor accessorL,
                                          VR right, ValueAccessor accessorR)
        {
            if (elementType.isByteOrderComparable)
                return ValueAccessor.compare(left, accessorL, right, accessorR);
            int offset = 0;
            int elementLength = elementType.valueLengthIfFixed();
            for (int i = 0; i < dimension; i++)
            {
                VL leftBytes = accessorL.slice(left, offset, elementLength);
                VR rightBytes = accessorR.slice(right, offset, elementLength);
                int rc = elementType.compare(leftBytes, accessorL, rightBytes, accessorR);
                if (rc != 0)
                    return rc;

                offset += elementLength;
            }
            return 0;
        }

        @Override
        public  List split(V buffer, ValueAccessor accessor)
        {
            List result = new ArrayList<>(dimension);
            int offset = 0;
            int elementLength = elementType.valueLengthIfFixed();
            for (int i = 0; i < dimension; i++)
            {
                V bb = accessor.slice(buffer, offset, elementLength);
                offset += elementLength;
                elementSerializer.validate(bb, accessor);
                result.add(bb);
            }
            checkConsumedFully(buffer, accessor, offset);

            return result;
        }

        @Override
        public  V serializeRaw(List value, ValueAccessor accessor)
        {
            if (value == null)
                rejectNullOrEmptyValue();

            check(value);

            int size = elementType.valueLengthIfFixed();
            V bb = accessor.allocate(size * dimension);
            int position = 0;
            for (V v : value)
                position += accessor.copyTo(v, 0, bb, accessor, position, size);
            return bb;
        }

        @Override
        public ByteBuffer serialize(List value)
        {
            if (value == null)
                rejectNullOrEmptyValue();

            check(value);

            ByteBuffer bb = ByteBuffer.allocate(elementType.valueLengthIfFixed() * dimension);
            for (T v : value)
                bb.put(elementSerializer.serialize(v).duplicate());
            bb.flip();
            return bb;
        }

        @Override
        public  List deserialize(V input, ValueAccessor accessor)
        {
            if (isNull(input, accessor))
                return null;
            List result = new ArrayList<>(dimension);
            int offset = 0;
            int elementLength = elementType.valueLengthIfFixed();
            for (int i = 0; i < dimension; i++)
            {
                V bb = accessor.slice(input, offset, elementLength);
                offset += elementLength;
                elementSerializer.validate(bb, accessor);
                result.add(elementSerializer.deserialize(bb, accessor));
            }
            checkConsumedFully(input, accessor, offset);

            return result;
        }

        @Override
        public  void validate(V input, ValueAccessor accessor) throws MarshalException
        {
            if (accessor.isEmpty(input))
                rejectNullOrEmptyValue();

            int offset = 0;
            int elementSize = elementType.valueLengthIfFixed();

            int expectedSize = elementSize * dimension;
            if (accessor.size(input) < expectedSize)
                throw new MarshalException("Not enough bytes to read a " + asCQL3Type());

            for (int i = 0; i < dimension; i++)
            {
                V bb = accessor.slice(input, offset, elementSize);
                offset += elementSize;
                elementSerializer.validate(bb, accessor);
            }
            checkConsumedFully(input, accessor, offset);
        }
    }

    private class VariableLengthSerializer extends VectorSerializer
    {
        private VariableLengthSerializer()
        {
        }

        @Override
        public  int compareCustom(VL left, ValueAccessor accessorL,
                                          VR right, ValueAccessor accessorR)
        {
            int leftOffset = 0;
            int rightOffset = 0;
            for (int i = 0; i < dimension; i++)
            {
                VL leftBytes = readValue(left, accessorL, leftOffset);
                leftOffset += sizeOf(leftBytes, accessorL);

                VR rightBytes = readValue(right, accessorR, rightOffset);
                rightOffset += sizeOf(rightBytes, accessorR);

                int rc = elementType.compare(leftBytes, accessorL, rightBytes, accessorR);
                if (rc != 0)
                    return rc;
            }
            return 0;
        }

        private  V readValue(V input, ValueAccessor accessor, int offset)
        {
            int size = accessor.getUnsignedVInt32(input, offset);
            if (size < 0)
                throw new AssertionError("Invalidate data at offset " + offset + "; saw size of " + size + " but only >= 0 is expected");

            return accessor.slice(input, offset + TypeSizes.sizeofUnsignedVInt(size), size);
        }

        private  int writeValue(V src, V dst, ValueAccessor accessor, int offset)
        {
            int size = accessor.size(src);
            int written = 0;
            written += accessor.putUnsignedVInt32(dst, offset + written, size);
            written += accessor.copyTo(src, 0, dst, accessor, offset + written, size);
            return written;
        }

        private  int sizeOf(V bb, ValueAccessor accessor)
        {
            return accessor.sizeWithVIntLength(bb);
        }

        @Override
        public  List split(V buffer, ValueAccessor accessor)
        {
            List result = new ArrayList<>(dimension);
            int offset = 0;
            for (int i = 0; i < dimension; i++)
            {
                V bb = readValue(buffer, accessor, offset);
                offset += sizeOf(bb, accessor);
                elementSerializer.validate(bb, accessor);
                result.add(bb);
            }
            checkConsumedFully(buffer, accessor, offset);

            return result;
        }

        @Override
        public  V serializeRaw(List value, ValueAccessor accessor)
        {
            if (value == null)
                rejectNullOrEmptyValue();

            check(value);

            V bb = accessor.allocate(value.stream().mapToInt(v -> sizeOf(v, accessor)).sum());
            int offset = 0;
            for (V b : value)
                offset += writeValue(b, bb, accessor, offset);
            return bb;
        }

        @Override
        public ByteBuffer serialize(List value)
        {
            if (value == null)
                rejectNullOrEmptyValue();

            check(value);

            List bbs = new ArrayList<>(dimension);
            for (int i = 0; i < dimension; i++)
                bbs.add(elementSerializer.serialize(value.get(i)));
            return serializeRaw(bbs, ByteBufferAccessor.instance);
        }

        @Override
        public  List deserialize(V input, ValueAccessor accessor)
        {
            if (isNull(input, accessor))
                return null;
            List result = new ArrayList<>(dimension);
            int offset = 0;
            for (int i = 0; i < dimension; i++)
            {
                V bb = readValue(input, accessor, offset);
                offset += sizeOf(bb, accessor);
                elementSerializer.validate(bb, accessor);
                result.add(elementSerializer.deserialize(bb, accessor));
            }
            checkConsumedFully(input, accessor, offset);

            return result;
        }

        @Override
        public  void validate(V input, ValueAccessor accessor) throws MarshalException
        {
            if (accessor.isEmpty(input))
                rejectNullOrEmptyValue();

            int offset = 0;
            for (int i = 0; i < dimension; i++)
            {
                if (offset >= accessor.size(input))
                    throw new MarshalException("Not enough bytes to read a " + asCQL3Type());

                V bb = readValue(input, accessor, offset);
                offset += sizeOf(bb, accessor);
                elementSerializer.validate(bb, accessor);
            }
            checkConsumedFully(input, accessor, offset);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy