org.datavec.api.writable.BytesWritable Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.writable;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
/**
 * An {@link ArrayWritable} backed by a raw {@code byte[]}.
 *
 * The typed accessors ({@link #getDouble(long)}, {@link #getFloat(long)},
 * {@link #getInt(long)}, {@link #getLong(long)}) interpret the bytes as a packed
 * array of fixed-width values (8 bytes for double/long, 4 bytes for float/int),
 * read through a lazily created {@link ByteBuffer} view of the content array
 * using that buffer's default byte order.
 */
@NoArgsConstructor
public class BytesWritable extends ArrayWritable {
    @Getter
    private byte[] content;

    /** Lazily created view over {@link #content}; dropped whenever the array is replaced. */
    private transient ByteBuffer cached;

    /**
     * Pass in the content for this writable
     * @param content the content for this writable
     */
    public BytesWritable(byte[] content) {
        this.content = content;
    }

    /**
     * Replace the content of this writable.
     *
     * The cached {@link ByteBuffer} view wraps the previous array, so it is
     * discarded here; otherwise the typed accessors would keep reading the
     * old bytes after the content has been swapped.
     *
     * @param content the new content for this writable
     */
    public void setContent(byte[] content) {
        this.content = content;
        this.cached = null;
    }

    /**
     * Convert the underlying contents of this {@link Writable}
     * to an nd4j {@link DataBuffer}. Note that this is a *copy*
     * of the underlying buffer.
     * Also note that {@link java.nio.ByteBuffer#allocateDirect(int)}
     * is used for allocation.
     * This should be considered an expensive operation.
     *
     * This buffer should be cached when used. Once used, this can be
     * used in standard Nd4j operations.
     *
     * Beyond that, the reason we have to use allocateDirect
     * is due to nd4j data buffers being stored off heap (whether on cpu or gpu)
     *
     * Only {@code DOUBLE}, {@code INT}, {@code FLOAT} and {@code LONG} are copied
     * element by element; for any other type no elements are written into the
     * returned buffer. {@code elementSize} is used solely to compute the number
     * of elements ({@code content.length / elementSize}); the per-element reads
     * use the fixed widths of the typed accessors.
     *
     * @param type the type of the data buffer
     * @param elementSize the size of each element in the buffer
     * @return the equivalent nd4j data buffer
     */
    public DataBuffer asNd4jBuffer(DataType type, int elementSize) {
        int length = content.length / elementSize;
        DataBuffer ret = Nd4j.createBuffer(ByteBuffer.allocateDirect(content.length), type, length, 0);
        for (int i = 0; i < length; i++) {
            switch (type) {
                case DOUBLE:
                    ret.put(i, getDouble(i));
                    break;
                case INT:
                    ret.put(i, getInt(i));
                    break;
                case FLOAT:
                    ret.put(i, getFloat(i));
                    break;
                case LONG:
                    ret.put(i, getLong(i));
                    break;
            }
        }
        return ret;
    }

    /**
     * @return the number of bytes in the content array
     */
    @Override
    public long length() {
        return content.length;
    }

    /**
     * Read the {@code i}-th 8-byte double from the content.
     *
     * @param i element index (not byte offset)
     * @return the double stored at byte offset {@code i * 8}
     * @throws IndexOutOfBoundsException if the index is negative or out of range
     */
    @Override
    public double getDouble(long i) {
        return cachedByteByteBuffer().getDouble(byteOffset(i, 8));
    }

    /**
     * Read the {@code i}-th 4-byte float from the content.
     *
     * @param i element index (not byte offset)
     * @return the float stored at byte offset {@code i * 4}
     * @throws IndexOutOfBoundsException if the index is negative or out of range
     */
    @Override
    public float getFloat(long i) {
        return cachedByteByteBuffer().getFloat(byteOffset(i, 4));
    }

    /**
     * Read the {@code i}-th 4-byte int from the content.
     *
     * @param i element index (not byte offset)
     * @return the int stored at byte offset {@code i * 4}
     * @throws IndexOutOfBoundsException if the index is negative or out of range
     */
    @Override
    public int getInt(long i) {
        return cachedByteByteBuffer().getInt(byteOffset(i, 4));
    }

    /**
     * Read the {@code i}-th 8-byte long from the content.
     *
     * @param i element index (not byte offset)
     * @return the long stored at byte offset {@code i * 8}
     * @throws IndexOutOfBoundsException if the index is negative or out of range
     */
    @Override
    public long getLong(long i) {
        return cachedByteByteBuffer().getLong(byteOffset(i, 8));
    }

    /**
     * Write the raw content bytes to the output. No length prefix is written.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.write(content);
    }

    /**
     * Fill the existing content array from the input, reading exactly
     * {@code content.length} bytes. The array must therefore already be
     * allocated to the expected size before this is called. Because the bytes
     * are read into the same array, a previously cached view stays valid.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        in.readFully(content);
    }

    /**
     * Write the type index of {@link #getType()} as a short.
     */
    @Override
    public void writeType(DataOutput out) throws IOException {
        out.writeShort(getType().typeIdx());
    }

    @Override
    public WritableType getType() {
        return WritableType.Bytes;
    }

    /**
     * Return the {@link ByteBuffer} view over the content array, creating it on first use.
     */
    private ByteBuffer cachedByteByteBuffer() {
        if (cached == null) {
            cached = ByteBuffer.wrap(content);
        }
        return cached;
    }

    /**
     * Translate an element index into a byte offset without silent truncation.
     *
     * Narrowing the long index to int before the bounds check would let an
     * out-of-range index (for example {@code 1L << 32}) wrap around to a valid
     * offset and return unrelated data, so the range is validated on the long
     * value first. The buffer itself still checks the offset against its limit.
     */
    private static int byteOffset(long index, int width) {
        if (index < 0 || index > Integer.MAX_VALUE / width) {
            throw new IndexOutOfBoundsException("Element index out of range: " + index);
        }
        return (int) (index * width);
    }

    /**
     * Two instances are equal when they are of the same class and their
     * content arrays hold the same bytes.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        BytesWritable that = (BytesWritable) o;
        return Arrays.equals(content, that.content);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(content);
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy