All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.api.writable.BytesWritable Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.writable;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

/**
 * A {@link Writable} backed by a raw {@code byte[]}.
 *
 * <p>The typed accessors ({@link #getDouble(long)}, {@link #getFloat(long)},
 * {@link #getInt(long)}, {@link #getLong(long)}) take an <em>element</em> index, not a
 * byte offset. They decode the bytes through a {@link ByteBuffer} wrapping the content
 * array, which uses {@link ByteBuffer}'s default big-endian byte order.
 * {@link #length()}, by contrast, reports the number of <em>bytes</em>.
 */
@NoArgsConstructor
public class BytesWritable extends ArrayWritable {
    /** Raw bytes held by this writable; may be {@code null} after no-arg construction. */
    @Getter
    private byte[] content;

    /**
     * Lazily created view over {@link #content}. It must be discarded whenever
     * {@code content} is replaced, otherwise reads would observe the previous array.
     */
    private transient ByteBuffer cached;

    /**
     * Pass in the content for this writable.
     * @param content the content for this writable (used directly, not copied)
     */
    public BytesWritable(byte[] content) {
        this.content = content;
    }

    /**
     * Replace the content of this writable.
     *
     * <p>Also drops the cached {@link ByteBuffer} view so subsequent typed reads see the
     * new array rather than the old one.
     * @param content the new content (used directly, not copied)
     */
    public void setContent(byte[] content) {
        this.content = content;
        this.cached = null;
    }

    /**
     * Convert the underlying contents of this {@link Writable}
     * to an nd4j {@link DataBuffer}. Note that this is a *copy*
     * of the underlying buffer.
     * Also note that {@link java.nio.ByteBuffer#allocateDirect(int)}
     * is used for allocation.
     * This should be considered an expensive operation.
     *
     * This buffer should be cached when used. Once used, this can be
     * used in standard Nd4j operations.
     *
     * Beyond that, the reason we have to use allocateDirect
     * is due to nd4j data buffers being stored off heap (whether on cpu or gpu)
     *
     * <p>The number of elements copied is {@code content.length / elementSize}; trailing
     * bytes that do not form a whole element are ignored. Elements are decoded with the
     * fixed widths of the typed accessors (8 bytes for DOUBLE/LONG, 4 bytes for
     * INT/FLOAT), so {@code elementSize} is expected to match {@code type}.
     *
     * @param type the type of the data buffer; only DOUBLE, INT, FLOAT and LONG are supported
     * @param elementSize the size of each element in the buffer, in bytes
     * @return the equivalent nd4j data buffer
     * @throws IllegalArgumentException if {@code elementSize} is not positive
     * @throws UnsupportedOperationException if {@code type} is not one of the supported types
     */
    public DataBuffer asNd4jBuffer(DataType type, int elementSize) {
        if (elementSize <= 0) {
            throw new IllegalArgumentException("elementSize must be positive, got " + elementSize);
        }
        // Reject unsupported types up front instead of silently returning a zero-filled buffer.
        if (type != DataType.DOUBLE && type != DataType.INT
                && type != DataType.FLOAT && type != DataType.LONG) {
            throw new UnsupportedOperationException(
                    "Unsupported data type for BytesWritable conversion: " + type);
        }
        int length = content.length / elementSize;
        DataBuffer ret = Nd4j.createBuffer(ByteBuffer.allocateDirect(content.length), type, length, 0);
        // Dispatch once on the type rather than once per element.
        switch (type) {
            case DOUBLE:
                for (int i = 0; i < length; i++) {
                    ret.put(i, getDouble(i));
                }
                break;
            case INT:
                for (int i = 0; i < length; i++) {
                    ret.put(i, getInt(i));
                }
                break;
            case FLOAT:
                for (int i = 0; i < length; i++) {
                    ret.put(i, getFloat(i));
                }
                break;
            case LONG:
                for (int i = 0; i < length; i++) {
                    ret.put(i, getLong(i));
                }
                break;
        }
        return ret;
    }

    /**
     * @return the number of bytes in the content (not the number of typed elements)
     */
    @Override
    public long length() {
        return content.length;
    }

    /**
     * Read the {@code i}-th 8-byte double (big-endian).
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public double getDouble(long i) {
        return cachedByteByteBuffer().getDouble(byteOffset(i, 8));
    }

    /**
     * Read the {@code i}-th 4-byte float (big-endian).
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public float getFloat(long i) {
        return cachedByteByteBuffer().getFloat(byteOffset(i, 4));
    }

    /**
     * Read the {@code i}-th 4-byte int (big-endian).
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public int getInt(long i) {
        return cachedByteByteBuffer().getInt(byteOffset(i, 4));
    }

    /**
     * Read the {@code i}-th 8-byte long (big-endian).
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public long getLong(long i) {
        return cachedByteByteBuffer().getLong(byteOffset(i, 8));
    }

    /**
     * Write the raw content bytes. No length prefix is written, so a reader must already
     * know how many bytes to expect.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.write(content);
    }

    /**
     * Fill the <em>existing</em> content array from {@code in}: exactly
     * {@code content.length} bytes are read. The content must therefore already be set to
     * an array of the expected size before calling this method.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        // The array is filled in place, so any cached ByteBuffer view remains valid.
        in.readFully(content);
    }

    @Override
    public void writeType(DataOutput out) throws IOException {
        out.writeShort(getType().typeIdx());
    }

    @Override
    public WritableType getType() {
        return WritableType.Bytes;
    }

    /** Lazily wrap {@link #content} in a (big-endian) {@link ByteBuffer} for typed reads. */
    private ByteBuffer cachedByteByteBuffer() {
        if(cached == null) {
            cached = ByteBuffer.wrap(content);
        }
        return cached;
    }

    /**
     * Convert an element index into a byte offset without int overflow.
     *
     * <p>The previous {@code (int) i * width} form could wrap around for large indices
     * and silently address an unrelated element.
     *
     * @param index element index
     * @param width element width in bytes (must be positive)
     * @return {@code index * width} as an int
     * @throws IndexOutOfBoundsException if the offset is negative or does not fit in an int
     */
    private static int byteOffset(long index, int width) {
        if (index < 0 || index > Integer.MAX_VALUE / width) {
            throw new IndexOutOfBoundsException(
                    "Element index " + index + " out of range for element width " + width);
        }
        return (int) index * width;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        BytesWritable that = (BytesWritable) o;
        return Arrays.equals(content, that.content);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(content);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy