org.apache.cassandra.db.marshal.VectorType Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cassandra-all Show documentation
The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.db.marshal;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.cql3.Term;
import org.apache.cassandra.cql3.Vectors;
import org.apache.cassandra.db.TypeSizes;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.serializers.MarshalException;
import org.apache.cassandra.serializers.TypeSerializer;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.JsonUtils;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
import org.apache.cassandra.utils.bytecomparable.ByteSource;
public final class VectorType extends AbstractType>
{
private static class Key
{
private final AbstractType type;
private final int dimension;
private Key(AbstractType type, int dimension)
{
this.type = type;
this.dimension = dimension;
}
private VectorType create()
{
return new VectorType<>(type, dimension);
}
@Override
public boolean equals(Object o)
{
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Key key = (Key) o;
return dimension == key.dimension && Objects.equals(type, key.type);
}
@Override
public int hashCode()
{
return Objects.hash(type, dimension);
}
}
@SuppressWarnings("rawtypes")
private static final ConcurrentHashMap instances = new ConcurrentHashMap<>();
public final AbstractType elementType;
public final int dimension;
private final TypeSerializer elementSerializer;
private final int valueLengthIfFixed;
private final VectorSerializer serializer;
private VectorType(AbstractType elementType, int dimension)
{
super(ComparisonType.CUSTOM);
if (dimension <= 0)
throw new InvalidRequestException(String.format("vectors may only have positive dimensions; given %d", dimension));
this.elementType = elementType;
this.dimension = dimension;
this.elementSerializer = elementType.getSerializer();
this.valueLengthIfFixed = elementType.isValueLengthFixed() ?
elementType.valueLengthIfFixed() * dimension :
super.valueLengthIfFixed();
this.serializer = elementType.isValueLengthFixed() ?
new FixedLengthSerializer() :
new VariableLengthSerializer();
}
@SuppressWarnings("unchecked")
public static VectorType getInstance(AbstractType elements, int dimension)
{
Key key = new Key(elements, dimension);
return instances.computeIfAbsent(key, Key::create);
}
public static VectorType getInstance(TypeParser parser)
{
TypeParser.Vector v = parser.getVectorParameters();
return getInstance(v.type.freeze(), v.dimension);
}
@Override
public boolean isVector()
{
return true;
}
@Override
public int compareCustom(VL left, ValueAccessor accessorL, VR right, ValueAccessor accessorR)
{
return getSerializer().compareCustom(left, accessorL, right, accessorR);
}
@Override
public int valueLengthIfFixed()
{
return valueLengthIfFixed;
}
@Override
public VectorSerializer getSerializer()
{
return serializer;
}
public List split(ByteBuffer buffer)
{
return split(buffer, ByteBufferAccessor.instance);
}
public List split(V buffer, ValueAccessor accessor)
{
return getSerializer().split(buffer, accessor);
}
public float[] composeAsFloat(ByteBuffer input)
{
return composeAsFloat(input, ByteBufferAccessor.instance);
}
public float[] composeAsFloat(V input, ValueAccessor accessor)
{
if (!(elementType instanceof FloatType))
throw new IllegalStateException("Attempted to read as float, but element type is " + elementType.asCQL3Type());
if (isNull(input, accessor))
return null;
return accessor.toFloatArray(input, dimension);
}
public ByteBuffer decompose(T... values)
{
return decompose(Arrays.asList(values));
}
public ByteBuffer decomposeAsFloat(float[] value)
{
return decomposeAsFloat(ByteBufferAccessor.instance, value);
}
public V decomposeAsFloat(ValueAccessor accessor, float[] value)
{
if (value == null)
rejectNullOrEmptyValue();
if (!(elementType instanceof FloatType))
throw new IllegalStateException("Attempted to read as float, but element type is " + elementType.asCQL3Type());
if (value.length != dimension)
throw new IllegalArgumentException(String.format("Attempted to add float vector of dimension %d to %s", value.length, asCQL3Type()));
// TODO : should we use TypeSizes to be consistent with other code? Its the same value at the end of the day...
V buffer = accessor.allocate(Float.BYTES * dimension);
int offset = 0;
for (int i = 0; i < dimension; i++)
{
accessor.putFloat(buffer, offset, value[i]);
offset+= Float.BYTES;
}
return buffer;
}
public ByteBuffer decomposeRaw(List elements)
{
return decomposeRaw(elements, ByteBufferAccessor.instance);
}
public V decomposeRaw(List elements, ValueAccessor accessor)
{
return getSerializer().serializeRaw(elements, accessor);
}
@Override
public ByteSource asComparableBytes(ValueAccessor accessor, V value, ByteComparable.Version version)
{
if (isNull(value, accessor))
return null;
ByteSource[] srcs = new ByteSource[dimension];
List split = split(value, accessor);
for (int i = 0; i < dimension; i++)
srcs[i] = elementType.asComparableBytes(accessor, split.get(i), version);
return ByteSource.withTerminatorMaybeLegacy(version, 0x00, srcs);
}
@Override
public V fromComparableBytes(ValueAccessor accessor, ByteSource.Peekable comparableBytes, ByteComparable.Version version)
{
if (comparableBytes == null)
rejectNullOrEmptyValue();
assert version != ByteComparable.Version.LEGACY; // legacy translation is not reversible
List buffers = new ArrayList<>();
int separator = comparableBytes.next();
while (separator != ByteSource.TERMINATOR)
{
buffers.add(elementType.fromComparableBytes(accessor, comparableBytes, version));
separator = comparableBytes.next();
}
return decomposeRaw(buffers, accessor);
}
@Override
public CQL3Type asCQL3Type()
{
return new CQL3Type.Vector(this);
}
public AbstractType getElementsType()
{
return elementType;
}
// vector of nested types is hard to parse, so fall back to bytes string matching ListType
@Override
public String getString(V value, ValueAccessor accessor)
{
return BytesType.instance.getString(value, accessor);
}
@Override
public ByteBuffer fromString(String source) throws MarshalException
{
try
{
return ByteBufferUtil.hexToBytes(source);
}
catch (NumberFormatException e)
{
throw new MarshalException(String.format("cannot parse '%s' as hex bytes", source), e);
}
}
@Override
public List> subTypes()
{
return Collections.singletonList(elementType);
}
@Override
public String toJSONString(ByteBuffer buffer, ProtocolVersion protocolVersion)
{
return toJSONString(buffer, ByteBufferAccessor.instance, protocolVersion);
}
@Override
public String toJSONString(V value, ValueAccessor accessor, ProtocolVersion protocolVersion)
{
StringBuilder sb = new StringBuilder();
sb.append('[');
List split = split(value, accessor);
for (int i = 0; i < dimension; i++)
{
if (i > 0)
sb.append(", ");
sb.append(elementType.toJSONString(split.get(i), accessor, protocolVersion));
}
sb.append(']');
return sb.toString();
}
@Override
public Term fromJSONObject(Object parsed) throws MarshalException
{
if (parsed instanceof String)
parsed = JsonUtils.decodeJson((String) parsed);
if (!(parsed instanceof List))
throw new MarshalException(String.format(
"Expected a list, but got a %s: %s", parsed.getClass().getSimpleName(), parsed));
List list = (List) parsed;
if (list.size() != dimension)
throw new MarshalException(String.format("List had incorrect size: expected %d but given %d; %s", dimension, list.size(), list));
List terms = new ArrayList<>(list.size());
for (Object element : list)
{
if (element == null)
throw new MarshalException("Invalid null element in list");
terms.add(elementType.fromJSONObject(element));
}
return new Vectors.DelayedValue<>(this, terms);
}
@Override
public boolean equals(Object o)
{
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VectorType that = (VectorType) o;
return dimension == that.dimension && Objects.equals(elementType, that.elementType);
}
@Override
public int hashCode()
{
return Objects.hash(elementType, dimension);
}
@Override
public String toString()
{
return toString(false);
}
@Override
public String toString(boolean ignoreFreezing)
{
return getClass().getName() + TypeParser.stringifyVectorParameters(elementType, ignoreFreezing, dimension);
}
private void check(List values)
{
if (values.size() != dimension)
throw new MarshalException(String.format("Required %d elements, but saw %d", dimension, values.size()));
// This code base always works with a list that is RandomAccess, so can use .get to avoid allocation
for (int i = 0; i < dimension; i++)
{
Object value = values.get(i);
if (value == null || (value instanceof ByteBuffer && elementSerializer.isNull((ByteBuffer) value)))
throw new MarshalException(String.format("Element at index %d is null (expected type %s); given %s", i, elementType.asCQL3Type(), values));
}
}
private void checkConsumedFully(V buffer, ValueAccessor accessor, int offset)
{
int remaining = accessor.sizeFromOffset(buffer, offset);
if (remaining > 0)
throw new MarshalException("Unexpected " + remaining + " extraneous bytes after " + asCQL3Type() + " value");
}
private static void rejectNullOrEmptyValue()
{
throw new MarshalException("Invalid empty vector value");
}
@Override
public ByteBuffer getMaskedValue()
{
List values = Collections.nCopies(dimension, elementType.getMaskedValue());
return serializer.serializeRaw(values, ByteBufferAccessor.instance);
}
public abstract class VectorSerializer extends TypeSerializer>
{
public abstract int compareCustom(VL left, ValueAccessor accessorL, VR right, ValueAccessor accessorR);
public abstract List split(V buffer, ValueAccessor accessor);
public abstract V serializeRaw(List elements, ValueAccessor accessor);
@Override
public String toString(List value)
{
StringBuilder sb = new StringBuilder();
boolean isFirst = true;
sb.append('[');
for (T element : value)
{
if (isFirst)
isFirst = false;
else
sb.append(", ");
sb.append(elementSerializer.toString(element));
}
sb.append(']');
return sb.toString();
}
@Override
@SuppressWarnings({ "rawtypes", "unchecked" })
public Class> getType()
{
return (Class) List.class;
}
@Override
public boolean isNull(@Nullable V buffer, ValueAccessor accessor)
{
// we don't allow empty vectors, so we can just check for null
return buffer == null;
}
}
private class FixedLengthSerializer extends VectorSerializer
{
private FixedLengthSerializer()
{
}
@Override
public int compareCustom(VL left, ValueAccessor accessorL,
VR right, ValueAccessor accessorR)
{
if (elementType.isByteOrderComparable)
return ValueAccessor.compare(left, accessorL, right, accessorR);
int offset = 0;
int elementLength = elementType.valueLengthIfFixed();
for (int i = 0; i < dimension; i++)
{
VL leftBytes = accessorL.slice(left, offset, elementLength);
VR rightBytes = accessorR.slice(right, offset, elementLength);
int rc = elementType.compare(leftBytes, accessorL, rightBytes, accessorR);
if (rc != 0)
return rc;
offset += elementLength;
}
return 0;
}
@Override
public List split(V buffer, ValueAccessor accessor)
{
List result = new ArrayList<>(dimension);
int offset = 0;
int elementLength = elementType.valueLengthIfFixed();
for (int i = 0; i < dimension; i++)
{
V bb = accessor.slice(buffer, offset, elementLength);
offset += elementLength;
elementSerializer.validate(bb, accessor);
result.add(bb);
}
checkConsumedFully(buffer, accessor, offset);
return result;
}
@Override
public V serializeRaw(List value, ValueAccessor accessor)
{
if (value == null)
rejectNullOrEmptyValue();
check(value);
int size = elementType.valueLengthIfFixed();
V bb = accessor.allocate(size * dimension);
int position = 0;
for (V v : value)
position += accessor.copyTo(v, 0, bb, accessor, position, size);
return bb;
}
@Override
public ByteBuffer serialize(List value)
{
if (value == null)
rejectNullOrEmptyValue();
check(value);
ByteBuffer bb = ByteBuffer.allocate(elementType.valueLengthIfFixed() * dimension);
for (T v : value)
bb.put(elementSerializer.serialize(v).duplicate());
bb.flip();
return bb;
}
@Override
public List deserialize(V input, ValueAccessor accessor)
{
if (isNull(input, accessor))
return null;
List result = new ArrayList<>(dimension);
int offset = 0;
int elementLength = elementType.valueLengthIfFixed();
for (int i = 0; i < dimension; i++)
{
V bb = accessor.slice(input, offset, elementLength);
offset += elementLength;
elementSerializer.validate(bb, accessor);
result.add(elementSerializer.deserialize(bb, accessor));
}
checkConsumedFully(input, accessor, offset);
return result;
}
@Override
public void validate(V input, ValueAccessor accessor) throws MarshalException
{
if (accessor.isEmpty(input))
rejectNullOrEmptyValue();
int offset = 0;
int elementSize = elementType.valueLengthIfFixed();
int expectedSize = elementSize * dimension;
if (accessor.size(input) < expectedSize)
throw new MarshalException("Not enough bytes to read a " + asCQL3Type());
for (int i = 0; i < dimension; i++)
{
V bb = accessor.slice(input, offset, elementSize);
offset += elementSize;
elementSerializer.validate(bb, accessor);
}
checkConsumedFully(input, accessor, offset);
}
}
private class VariableLengthSerializer extends VectorSerializer
{
private VariableLengthSerializer()
{
}
@Override
public int compareCustom(VL left, ValueAccessor accessorL,
VR right, ValueAccessor accessorR)
{
int leftOffset = 0;
int rightOffset = 0;
for (int i = 0; i < dimension; i++)
{
VL leftBytes = readValue(left, accessorL, leftOffset);
leftOffset += sizeOf(leftBytes, accessorL);
VR rightBytes = readValue(right, accessorR, rightOffset);
rightOffset += sizeOf(rightBytes, accessorR);
int rc = elementType.compare(leftBytes, accessorL, rightBytes, accessorR);
if (rc != 0)
return rc;
}
return 0;
}
private V readValue(V input, ValueAccessor accessor, int offset)
{
int size = accessor.getUnsignedVInt32(input, offset);
if (size < 0)
throw new AssertionError("Invalidate data at offset " + offset + "; saw size of " + size + " but only >= 0 is expected");
return accessor.slice(input, offset + TypeSizes.sizeofUnsignedVInt(size), size);
}
private int writeValue(V src, V dst, ValueAccessor accessor, int offset)
{
int size = accessor.size(src);
int written = 0;
written += accessor.putUnsignedVInt32(dst, offset + written, size);
written += accessor.copyTo(src, 0, dst, accessor, offset + written, size);
return written;
}
private int sizeOf(V bb, ValueAccessor accessor)
{
return accessor.sizeWithVIntLength(bb);
}
@Override
public List split(V buffer, ValueAccessor accessor)
{
List result = new ArrayList<>(dimension);
int offset = 0;
for (int i = 0; i < dimension; i++)
{
V bb = readValue(buffer, accessor, offset);
offset += sizeOf(bb, accessor);
elementSerializer.validate(bb, accessor);
result.add(bb);
}
checkConsumedFully(buffer, accessor, offset);
return result;
}
@Override
public V serializeRaw(List value, ValueAccessor accessor)
{
if (value == null)
rejectNullOrEmptyValue();
check(value);
V bb = accessor.allocate(value.stream().mapToInt(v -> sizeOf(v, accessor)).sum());
int offset = 0;
for (V b : value)
offset += writeValue(b, bb, accessor, offset);
return bb;
}
@Override
public ByteBuffer serialize(List value)
{
if (value == null)
rejectNullOrEmptyValue();
check(value);
List bbs = new ArrayList<>(dimension);
for (int i = 0; i < dimension; i++)
bbs.add(elementSerializer.serialize(value.get(i)));
return serializeRaw(bbs, ByteBufferAccessor.instance);
}
@Override
public List deserialize(V input, ValueAccessor accessor)
{
if (isNull(input, accessor))
return null;
List result = new ArrayList<>(dimension);
int offset = 0;
for (int i = 0; i < dimension; i++)
{
V bb = readValue(input, accessor, offset);
offset += sizeOf(bb, accessor);
elementSerializer.validate(bb, accessor);
result.add(elementSerializer.deserialize(bb, accessor));
}
checkConsumedFully(input, accessor, offset);
return result;
}
@Override
public void validate(V input, ValueAccessor accessor) throws MarshalException
{
if (accessor.isEmpty(input))
rejectNullOrEmptyValue();
int offset = 0;
for (int i = 0; i < dimension; i++)
{
if (offset >= accessor.size(input))
throw new MarshalException("Not enough bytes to read a " + asCQL3Type());
V bb = readValue(input, accessor, offset);
offset += sizeOf(bb, accessor);
elementSerializer.validate(bb, accessor);
}
checkConsumedFully(input, accessor, offset);
}
}
}