All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.io.Arrow Maven / Gradle / Ivy

There is a newer version: 2.6.0
Show newest version
/*******************************************************************************
 * Copyright (c) 2010-2020 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 ******************************************************************************/

package smile.io;

import java.io.InputStream;
import java.io.OutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.math.BigDecimal;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.*;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import static org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE;
import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE;
import smile.data.DataFrame;
import smile.data.type.*;

/**
 * Apache Arrow is a cross-language development platform for in-memory data.
 * It specifies a standardized language-independent columnar memory format
 * for flat and hierarchical data, organized for efficient analytic
 * operations on modern hardware.
 *
 * @author Haifeng Li
 */
public class Arrow {
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Arrow.class);

    /**
     * The root allocator. Typically only one created for a JVM.
     * Arrow provides a tree-based model for memory allocation.
     * The RootAllocator is created first, then all allocators are
     * created as children of that allocator. The RootAllocator is
     * responsible for being the master bookkeeper for memory
     * allocations. All other allocators are created as children
     * of this tree. Each allocator can first determine whether
     * it has enough local memory to satisfy a particular request.
     * If not, the allocator can ask its parent for an additional
     * memory allocation.
     * <p>
     * The field is shared by all instances of this class. It is set by
     * {@link #allocate(long)}; the read and write methods call
     * {@code allocate(Long.MAX_VALUE)} themselves when it is still
     * {@code null}.
     */
    private static RootAllocator allocator;
    /**
     * The number of records in a record batch.
     * An Apache Arrow record batch is conceptually similar
     * to the Parquet row group. Parquet recommends a
     * disk/block/row group/file size of 512 to 1024 MB on HDFS.
     * 1 million rows x 100 columns of double will be about
     * 800 MB and also cover many use cases in machine learning.
     * <p>
     * The constructor rejects non-positive values. The writer uses
     * this value as the maximum row count of each batch it emits.
     */
    private int batch;

    /**
     * Constructor. Uses the default record batch size of
     * one million records.
     */
    public Arrow() {
        this(1_000_000);
    }

    /**
     * Constructor.
     * @param batch the number of records in a record batch.
     */
    public Arrow(int batch) {
        if (batch <= 0) {
            throw new IllegalArgumentException("Invalid batch size: " + batch);
        }

        this.batch = batch;
    }

    /**
     * Creates the root allocator.
     * The RootAllocator is responsible for being the master
     * bookkeeper for memory allocations.
     * <p>
     * The new allocator replaces whatever allocator was installed
     * before; the previous one is not closed by this method.
     *
     * @param limit memory allocation limit in bytes. Must be positive.
     * @throws IllegalArgumentException if {@code limit} is zero or negative.
     */
    public static void allocate(long limit) {
        final boolean positive = limit > 0;
        if (!positive) {
            throw new IllegalArgumentException("Invalid RootAllocator limit: " + limit);
        }

        final RootAllocator root = new RootAllocator(limit);
        allocator = root;
    }

    /**
     * Reads all records of an arrow file.
     * Equivalent to {@code read(path, Integer.MAX_VALUE)}.
     * @param path an Apache Arrow file path.
     * @return the data frame holding the content of the file.
     * @throws IOException if the file cannot be opened or read.
     */
    public DataFrame read(Path path) throws IOException {
        final int unlimited = Integer.MAX_VALUE;
        return read(path, unlimited);
    }

    /**
     * Reads a limited number of records from an arrow file.
     * @param path an Apache Arrow file path.
     * @param limit reads a limited number of records.
     * @return the data frame holding the records that were read.
     * @throws IOException if the file cannot be opened or read.
     */
    public DataFrame read(Path path, int limit) throws IOException {
        // Own the stream here so that it is closed even when the
        // Arrow reader fails before it takes over the stream.
        try (InputStream input = Files.newInputStream(path)) {
            return read(input, limit);
        }
    }

    /**
     * Reads all records of an arrow file.
     * Equivalent to {@code read(path, Integer.MAX_VALUE)}.
     * @param path an Apache Arrow file path or URI.
     * @return the data frame holding the content of the file.
     * @throws IOException if the source cannot be opened or read.
     * @throws URISyntaxException if {@code path} cannot be resolved.
     */
    public DataFrame read(String path) throws IOException, URISyntaxException {
        final int unlimited = Integer.MAX_VALUE;
        return read(path, unlimited);
    }

    /**
     * Reads a limited number of records from an arrow file.
     * @param path an Apache Arrow file path or URI.
     * @param limit reads a limited number of records.
     * @return the data frame holding the records that were read.
     * @throws IOException if the source cannot be opened or read.
     * @throws URISyntaxException if {@code path} cannot be resolved.
     */
    public DataFrame read(String path, int limit) throws IOException, URISyntaxException {
        // Own the stream here so that it is closed even when the
        // Arrow reader fails before it takes over the stream.
        try (InputStream input = Input.stream(path)) {
            return read(input, limit);
        }
    }

    /**
     * Reads a limited number of records from an arrow file.
     * <p>
     * Record batches are loaded one by one until the stream is exhausted
     * or the number of rows read so far reaches {@code limit}. Batches are
     * never truncated, so the result may hold more than {@code limit} rows.
     * The reader, and with it the input stream, is closed on return.
     *
     * @param input an Apache Arrow file input stream.
     * @param limit reads a limited number of records.
     * @return the data frame; multiple batches are combined with
     *         {@code DataFrame.union}.
     * @throws IOException if the stream cannot be read.
     * @throws IllegalStateException if no record batch was read.
     * @throws UnsupportedOperationException if a column type is not supported.
     */
    public DataFrame read(InputStream input, int limit) throws IOException {
        if (allocator == null) {
            allocate(Long.MAX_VALUE);
        }

        try (ArrowStreamReader reader = new ArrowStreamReader(input, allocator)) {

            // The holder for a set of vectors to be loaded/unloaded.
            VectorSchemaRoot root = reader.getVectorSchemaRoot();
            List<DataFrame> frames = new ArrayList<>();
            // Number of rows read so far; long so that it cannot overflow.
            long size = 0;
            // Check the limit first so no batch is loaded only to be dropped.
            while (size < limit && reader.loadNextBatch()) {
                List<FieldVector> fieldVectors = root.getFieldVectors();
                logger.info("read {} rows and {} columns", root.getRowCount(), fieldVectors.size());

                smile.data.vector.BaseVector[] vectors = new smile.data.vector.BaseVector[fieldVectors.size()];
                for (int j = 0; j < vectors.length; j++) {
                    vectors[j] = readFieldVector(fieldVectors.get(j));
                }

                DataFrame frame = DataFrame.of(vectors);
                frames.add(frame);
                // Count rows of this batch, not the number of batches.
                size += frame.size();
            }

            if (frames.isEmpty()) {
                throw new IllegalStateException("No record batch");
            } else if (frames.size() == 1) {
                return frames.get(0);
            } else {
                DataFrame df = frames.get(0);
                return df.union(frames.subList(1, frames.size()).toArray(new DataFrame[frames.size() - 1]));
            }
        }
    }

    /** Converts one Arrow column of the loaded batch, dispatching on its Arrow type. */
    private smile.data.vector.BaseVector readFieldVector(FieldVector fieldVector) {
        ArrowType type = fieldVector.getField().getType();
        switch (type.getTypeID()) {
            case Int: {
                ArrowType.Int itype = (ArrowType.Int) type;
                int bitWidth = itype.getBitWidth();
                switch (bitWidth) {
                    case 8:
                        return readByteField(fieldVector);
                    case 16:
                        // Unsigned 16-bit integers are mapped to char.
                        return itype.getIsSigned() ? readShortField(fieldVector) : readCharField(fieldVector);
                    case 32:
                        return readIntField(fieldVector);
                    case 64:
                        return readLongField(fieldVector);
                    default:
                        throw new UnsupportedOperationException("Unsupported integer bit width: " + bitWidth);
                }
            }
            case FloatingPoint: {
                FloatingPointPrecision precision = ((ArrowType.FloatingPoint) type).getPrecision();
                switch (precision) {
                    case DOUBLE:
                        return readDoubleField(fieldVector);
                    case SINGLE:
                        return readFloatField(fieldVector);
                    default:
                        // HALF precision has no counterpart here.
                        throw new UnsupportedOperationException("Unsupported float precision: " + precision);
                }
            }
            case Decimal:
                return readDecimalField(fieldVector);
            case Bool:
                return readBitField(fieldVector);
            case Date:
                return readDateField(fieldVector);
            case Time:
                return readTimeField(fieldVector);
            case Timestamp:
                return readDateTimeField(fieldVector);
            case Binary:
            case FixedSizeBinary:
                return readByteArrayField(fieldVector);
            case Utf8:
                return readStringField(fieldVector);
            default:
                throw new UnsupportedOperationException("Unsupported column type: " + fieldVector.getMinorType());
        }
    }

    /**
     * Writes the DataFrame to a file.
     * <p>
     * The rows are emitted as a sequence of record batches, each holding
     * at most the configured batch size. The root allocator is created
     * with a limit of {@code Long.MAX_VALUE} if none exists yet.
     *
     * @param df the data frame to write.
     * @param path the output file path.
     * @throws IOException if the file cannot be written.
     * @throws UnsupportedOperationException if a column type is not supported.
     */
    public void write(DataFrame df, Path path) throws IOException {
        if (allocator == null) {
            allocate(Long.MAX_VALUE);
        }

        final Schema arrowSchema = toArrowSchema(df.schema());
        /*
         * A dictionary encoded field stores Int32 indices into a
         * dictionary instead of the values themselves. The dictionary
         * arrives as one or more DictionaryBatches whose id is referenced
         * by the dictionary attribute of the Field table in the metadata
         * (Message.fbs), and has the layout the field type would dictate.
         * A Schema that references a dictionary id must send at least one
         * DictionaryBatch for that id.
         */
        final DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider();
        try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator);
             OutputStream output = Files.newOutputStream(path);
             ArrowStreamWriter writer = new ArrowStreamWriter(root, dictionaries, output)) {

            writer.start();
            final int total = df.size();
            int from = 0;
            while (from < total) {
                // The last batch may be shorter than the batch size.
                final int count = Math.min(batch, total - from);
                root.setRowCount(count);

                for (Field field : root.getSchema().getFields()) {
                    final FieldVector vector = root.getVector(field.getName());
                    final DataType type = df.schema().field(field.getName()).type;
                    switch (type.id()) {
                        case Integer:
                            writeIntField(df, vector, from, count);
                            break;
                        case Long:
                            writeLongField(df, vector, from, count);
                            break;
                        case Double:
                            writeDoubleField(df, vector, from, count);
                            break;
                        case Float:
                            writeFloatField(df, vector, from, count);
                            break;
                        case Boolean:
                            writeBooleanField(df, vector, from, count);
                            break;
                        case Byte:
                            writeByteField(df, vector, from, count);
                            break;
                        case Short:
                            writeShortField(df, vector, from, count);
                            break;
                        case Char:
                            writeCharField(df, vector, from, count);
                            break;
                        case String:
                            writeStringField(df, vector, from, count);
                            break;
                        case Date:
                            writeDateField(df, vector, from, count);
                            break;
                        case Time:
                            writeTimeField(df, vector, from, count);
                            break;
                        case DateTime:
                            writeDateTimeField(df, vector, from, count);
                            break;
                        case Object: {
                            // Boxed columns are dispatched on the element class.
                            final Class<?> clazz = ((ObjectType) type).getObjectClass();
                            if (clazz == Integer.class) {
                                writeIntObjectField(df, vector, from, count);
                            } else if (clazz == Long.class) {
                                writeLongObjectField(df, vector, from, count);
                            } else if (clazz == Double.class) {
                                writeDoubleObjectField(df, vector, from, count);
                            } else if (clazz == Float.class) {
                                writeFloatObjectField(df, vector, from, count);
                            } else if (clazz == Boolean.class) {
                                writeBooleanObjectField(df, vector, from, count);
                            } else if (clazz == Byte.class) {
                                writeByteObjectField(df, vector, from, count);
                            } else if (clazz == Short.class) {
                                writeShortObjectField(df, vector, from, count);
                            } else if (clazz == Character.class) {
                                writeCharObjectField(df, vector, from, count);
                            } else if (clazz == BigDecimal.class) {
                                writeDecimalField(df, vector, from, count);
                            } else if (clazz == String.class) {
                                writeStringField(df, vector, from, count);
                            } else if (clazz == LocalDate.class) {
                                writeDateField(df, vector, from, count);
                            } else if (clazz == LocalTime.class) {
                                writeTimeField(df, vector, from, count);
                            } else if (clazz == LocalDateTime.class) {
                                writeDateTimeField(df, vector, from, count);
                            } else {
                                throw new UnsupportedOperationException("Unsupported type: " + type);
                            }
                            break;
                        }
                        case Array: {
                            // Only byte[] columns can be written.
                            final DataType elementType = ((ArrayType) type).getComponentType();
                            switch (elementType.id()) {
                                case Byte:
                                    writeByteArrayField(df, vector, from, count);
                                    break;
                                default:
                                    throw new UnsupportedOperationException("Unsupported type: " + type);
                            }
                            break;
                        }

                        default:
                            throw new UnsupportedOperationException("Unsupported type: " + type);
                    }
                }

                writer.writeBatch();
                logger.info("write {} rows", count);
                from += batch;
            }
        }
    }

    /**
     * Reads a boolean column. A non-nullable field yields a primitive
     * boolean vector; a nullable field yields a vector of boxed
     * {@code Boolean} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readBitField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final BitVector bits = (BitVector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Boolean[] boxed = new Boolean[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = bits.isNull(row) ? null : Boolean.valueOf(bits.get(row) != 0);
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Boolean.class, boxed);
        }

        final boolean[] values = new boolean[n];
        for (int row = 0; row < n; row++) {
            values[row] = bits.get(row) != 0;
        }

        return smile.data.vector.BooleanVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a byte column. A non-nullable field yields a primitive
     * byte vector; a nullable field yields a vector of boxed
     * {@code Byte} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readByteField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final TinyIntVector bytes = (TinyIntVector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Byte[] boxed = new Byte[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = bytes.isNull(row) ? null : Byte.valueOf(bytes.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Byte.class, boxed);
        }

        final byte[] values = new byte[n];
        for (int row = 0; row < n; row++) {
            values[row] = bytes.get(row);
        }

        return smile.data.vector.ByteVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a char column, i.e. a 16-bit integer column whose values are
     * interpreted as UTF-16 code units. A non-nullable field yields a
     * primitive char vector; a nullable field yields a vector of boxed
     * {@code Character} where nulls are preserved.
     *
     * @param fieldVector a {@code UInt2Vector} (unsigned 16-bit, the vector
     *                    type the writer uses for char columns) or a
     *                    {@code SmallIntVector}.
     */
    private smile.data.vector.BaseVector readCharField(FieldVector fieldVector) {
        int count = fieldVector.getValueCount();
        String name = fieldVector.getField().getName();

        // Char columns are written as UInt2Vector, so that is the vector
        // expected here. SmallIntVector is still accepted as a fallback.
        UInt2Vector unsigned = fieldVector instanceof UInt2Vector ? (UInt2Vector) fieldVector : null;
        SmallIntVector signed = unsigned == null ? (SmallIntVector) fieldVector : null;

        if (!fieldVector.getField().isNullable()) {
            char[] a = new char[count];
            for (int i = 0; i < count; i++) {
                a[i] = unsigned != null ? unsigned.get(i) : (char) signed.get(i);
            }

            return smile.data.vector.CharVector.of(name, a);
        } else {
            Character[] a = new Character[count];
            for (int i = 0; i < count; i++) {
                if (fieldVector.isNull(i)) {
                    a[i] = null;
                } else {
                    a[i] = unsigned != null ? unsigned.get(i) : (char) signed.get(i);
                }
            }

            return smile.data.vector.Vector.of(name, Character.class, a);
        }
    }

    /**
     * Reads a short column. A non-nullable field yields a primitive
     * short vector; a nullable field yields a vector of boxed
     * {@code Short} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readShortField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final SmallIntVector shorts = (SmallIntVector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Short[] boxed = new Short[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = shorts.isNull(row) ? null : Short.valueOf(shorts.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Short.class, boxed);
        }

        final short[] values = new short[n];
        for (int row = 0; row < n; row++) {
            values[row] = shorts.get(row);
        }

        return smile.data.vector.ShortVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads an int column. A non-nullable field yields a primitive
     * int vector; a nullable field yields a vector of boxed
     * {@code Integer} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readIntField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final IntVector ints = (IntVector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Integer[] boxed = new Integer[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = ints.isNull(row) ? null : Integer.valueOf(ints.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Integer.class, boxed);
        }

        final int[] values = new int[n];
        for (int row = 0; row < n; row++) {
            values[row] = ints.get(row);
        }

        return smile.data.vector.IntVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a long column. A non-nullable field yields a primitive
     * long vector; a nullable field yields a vector of boxed
     * {@code Long} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readLongField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final BigIntVector longs = (BigIntVector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Long[] boxed = new Long[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = longs.isNull(row) ? null : Long.valueOf(longs.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Long.class, boxed);
        }

        final long[] values = new long[n];
        for (int row = 0; row < n; row++) {
            values[row] = longs.get(row);
        }

        return smile.data.vector.LongVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a float column. A non-nullable field yields a primitive
     * float vector; a nullable field yields a vector of boxed
     * {@code Float} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readFloatField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final Float4Vector floats = (Float4Vector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Float[] boxed = new Float[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = floats.isNull(row) ? null : Float.valueOf(floats.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Float.class, boxed);
        }

        final float[] values = new float[n];
        for (int row = 0; row < n; row++) {
            values[row] = floats.get(row);
        }

        return smile.data.vector.FloatVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a double column. A non-nullable field yields a primitive
     * double vector; a nullable field yields a vector of boxed
     * {@code Double} where nulls are preserved.
     */
    private smile.data.vector.BaseVector readDoubleField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final Float8Vector doubles = (Float8Vector) fieldVector;

        if (fieldVector.getField().isNullable()) {
            final Double[] boxed = new Double[n];
            for (int row = 0; row < n; row++) {
                boxed[row] = doubles.isNull(row) ? null : Double.valueOf(doubles.get(row));
            }

            return smile.data.vector.Vector.of(fieldVector.getField().getName(), Double.class, boxed);
        }

        final double[] values = new double[n];
        for (int row = 0; row < n; row++) {
            values[row] = doubles.get(row);
        }

        return smile.data.vector.DoubleVector.of(fieldVector.getField().getName(), values);
    }

    /**
     * Reads a decimal column into a vector of {@code BigDecimal}.
     * Null entries of the Arrow vector stay null.
     */
    private smile.data.vector.BaseVector readDecimalField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final BigDecimal[] decimals = new BigDecimal[n];
        final DecimalVector source = (DecimalVector) fieldVector;
        for (int row = 0; row < n; row++) {
            // Array slots start out as null, so only defined cells are copied.
            if (!source.isNull(row)) {
                decimals[row] = source.getObject(row);
            }
        }

        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.DecimalType, decimals);
    }

    /**
     * Reads a date column into a vector of {@code LocalDate}.
     * Null entries of the Arrow vector stay null.
     * <p>
     * Day-unit values are days since the UNIX epoch. Millisecond-unit
     * values are milliseconds since the UNIX epoch and are converted to
     * a calendar day without applying the local time zone, so the
     * result does not depend on where the JVM runs.
     *
     * @throws UnsupportedOperationException if the vector is neither a
     *         {@code DateDayVector} nor a {@code DateMilliVector}.
     */
    private smile.data.vector.BaseVector readDateField(FieldVector fieldVector) {
        final long millisPerDay = 86_400_000L;
        int count = fieldVector.getValueCount();
        LocalDate[] a = new LocalDate[count];
        if (fieldVector instanceof DateDayVector) {
            DateDayVector vector = (DateDayVector) fieldVector;
            for (int i = 0; i < count; i++) {
                a[i] = vector.isNull(i) ? null : LocalDate.ofEpochDay(vector.get(i));
            }
        } else if (fieldVector instanceof DateMilliVector) {
            DateMilliVector vector = (DateMilliVector) fieldVector;
            for (int i = 0; i < count; i++) {
                // floorDiv keeps dates before 1970 on the right day.
                a[i] = vector.isNull(i) ? null : LocalDate.ofEpochDay(Math.floorDiv(vector.get(i), millisPerDay));
            }
        } else {
            // Fail loudly instead of returning a column of nulls.
            throw new UnsupportedOperationException("Unsupported date vector: " + fieldVector);
        }

        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.DateType, a);
    }

    /**
     * Reads a time-of-day column into a vector of {@code LocalTime}.
     * Second, millisecond, microsecond and nanosecond vectors are
     * supported. Null entries of the Arrow vector stay null.
     *
     * @throws UnsupportedOperationException if the vector is not one of
     *         the four Arrow time vectors.
     */
    private smile.data.vector.BaseVector readTimeField(FieldVector fieldVector) {
        int count = fieldVector.getValueCount();
        LocalTime[] a = new LocalTime[count];
        if (fieldVector instanceof TimeNanoVector) {
            TimeNanoVector vector = (TimeNanoVector) fieldVector;
            for (int i = 0; i < count; i++) {
                a[i] = vector.isNull(i) ? null : LocalTime.ofNanoOfDay(vector.get(i));
            }
        } else if (fieldVector instanceof TimeMilliVector) {
            TimeMilliVector vector = (TimeMilliVector) fieldVector;
            for (int i = 0; i < count; i++) {
                // The cell is an int; multiply as long to avoid int overflow
                // (any time after ~2.1 seconds past midnight would wrap).
                a[i] = vector.isNull(i) ? null : LocalTime.ofNanoOfDay(vector.get(i) * 1_000_000L);
            }
        } else if (fieldVector instanceof TimeMicroVector) {
            TimeMicroVector vector = (TimeMicroVector) fieldVector;
            for (int i = 0; i < count; i++) {
                a[i] = vector.isNull(i) ? null : LocalTime.ofNanoOfDay(vector.get(i) * 1_000L);
            }
        } else if (fieldVector instanceof TimeSecVector) {
            TimeSecVector vector = (TimeSecVector) fieldVector;
            for (int i = 0; i < count; i++) {
                a[i] = vector.isNull(i) ? null : LocalTime.ofSecondOfDay(vector.get(i));
            }
        } else {
            // Fail loudly instead of returning a column of nulls.
            throw new UnsupportedOperationException("Unsupported time vector: " + fieldVector);
        }
        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.TimeType, a);
    }

    /**
     * Reads a DateTime column into a vector of {@code LocalDateTime}.
     * Null entries of the Arrow vector stay null.
     * <p>
     * The resolution is taken from the Arrow timestamp type, so vectors
     * with and without a time zone are handled alike and sub-millisecond
     * precision is preserved. If the type carries a time zone, the epoch
     * value is rendered in that zone (offsets such as {@code +02:00} and
     * region ids such as {@code Europe/Paris} are both accepted);
     * otherwise the current offset of the JVM default zone is used.
     *
     * @throws java.time.DateTimeException if the time zone id is invalid.
     * @throws UnsupportedOperationException if the time unit is unknown.
     */
    private smile.data.vector.BaseVector readDateTimeField(FieldVector fieldVector) {
        int count = fieldVector.getValueCount();
        LocalDateTime[] a = new LocalDateTime[count];
        TimeStampVector vector = (TimeStampVector) fieldVector;
        ArrowType.Timestamp type = (ArrowType.Timestamp) fieldVector.getField().getType();

        String timezone = type.getTimezone();
        ZoneId zone = (timezone == null || timezone.isEmpty())
                ? OffsetDateTime.now().getOffset()
                : ZoneId.of(timezone);

        // How many stored units make one second, and how many
        // nanoseconds one stored unit is worth.
        final long unitsPerSecond;
        final long nanosPerUnit;
        switch (type.getUnit()) {
            case SECOND:
                unitsPerSecond = 1L;
                nanosPerUnit = 0L;
                break;
            case MILLISECOND:
                unitsPerSecond = 1_000L;
                nanosPerUnit = 1_000_000L;
                break;
            case MICROSECOND:
                unitsPerSecond = 1_000_000L;
                nanosPerUnit = 1_000L;
                break;
            case NANOSECOND:
                unitsPerSecond = 1_000_000_000L;
                nanosPerUnit = 1L;
                break;
            default:
                throw new UnsupportedOperationException("Unsupported timestamp unit: " + type.getUnit());
        }

        for (int i = 0; i < count; i++) {
            if (vector.isNull(i)) {
                a[i] = null;
            } else {
                long value = vector.get(i);
                // floorDiv/floorMod keep instants before 1970 exact.
                long seconds = Math.floorDiv(value, unitsPerSecond);
                long nanos = Math.floorMod(value, unitsPerSecond) * nanosPerUnit;
                a[i] = LocalDateTime.ofInstant(Instant.ofEpochSecond(seconds, nanos), zone);
            }
        }

        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.DateTimeType, a);
    }

    /**
     * Reads a byte[] column from a variable-size or fixed-size binary
     * vector. Null entries of the Arrow vector stay null.
     *
     * @throws UnsupportedOperationException for any other vector type.
     */
    private smile.data.vector.BaseVector readByteArrayField(FieldVector fieldVector) {
        final int n = fieldVector.getValueCount();
        final byte[][] blobs = new byte[n][];
        if (fieldVector instanceof VarBinaryVector) {
            final VarBinaryVector source = (VarBinaryVector) fieldVector;
            for (int row = 0; row < n; row++) {
                blobs[row] = source.isNull(row) ? null : source.get(row);
            }
        } else if (fieldVector instanceof FixedSizeBinaryVector) {
            final FixedSizeBinaryVector source = (FixedSizeBinaryVector) fieldVector;
            for (int row = 0; row < n; row++) {
                blobs[row] = source.isNull(row) ? null : source.get(row);
            }
        } else {
            throw new UnsupportedOperationException("Unsupported binary vector: " + fieldVector);
        }

        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.ByteArrayType, blobs);
    }

    /**
     * Reads a String column. The bytes of each cell are decoded as
     * UTF-8, independent of the platform default charset.
     * Null entries of the Arrow vector stay null.
     */
    private smile.data.vector.BaseVector readStringField(FieldVector fieldVector) {
        int count = fieldVector.getValueCount();
        VarCharVector vector = (VarCharVector) fieldVector;
        String[] a = new String[count];
        for (int i = 0; i < count; i++) {
            if (vector.isNull(i)) {
                a[i] = null;
            } else {
                // This method is only dispatched for Utf8 columns.
                a[i] = new String(vector.get(i), java.nio.charset.StandardCharsets.UTF_8);
            }
        }

        return smile.data.vector.Vector.of(fieldVector.getField().getName(), DataTypes.StringType, a);
    }

    /**
     * Writes an int column.
     * Copies {@code count} rows of the data frame column, starting at
     * row {@code from}, into positions {@code 0..count-1} of the vector.
     */
    private void writeIntField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        final smile.data.vector.IntVector source = df.intVector(fieldVector.getField().getName());
        final IntVector target = (IntVector) fieldVector;
        for (int offset = 0; offset < count; offset++) {
            target.set(offset, source.getInt(from + offset));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a nullable int column.
     * Copies {@code count} rows of the data frame column, starting at
     * row {@code from}, into positions {@code 0..count-1} of the vector.
     * Null elements are written as Arrow nulls.
     */
    private void writeIntObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        IntVector vector = (IntVector) fieldVector;
        smile.data.vector.Vector<Integer> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i indexes the batch, j the data frame row; reading with i
            // would repeat the first batch in every later batch.
            Integer x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a boolean column.
     * Copies {@code count} rows of the data frame column, starting at
     * row {@code from}, into positions {@code 0..count-1} of the vector,
     * using the integer form returned by {@code getInt}.
     */
    private void writeBooleanField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        final smile.data.vector.BooleanVector source = df.booleanVector(fieldVector.getField().getName());
        final BitVector target = (BitVector) fieldVector;
        for (int offset = 0; offset < count; offset++) {
            target.set(offset, source.getInt(from + offset));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable boolean column into an Arrow bit vector.
     * A {@code null} element is stored as an Arrow null; otherwise true is
     * written as 1 and false as 0.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code BitVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeBooleanObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        BitVector vector = (BitVector) fieldVector;
        smile.data.vector.Vector<Boolean> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Boolean x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x ? 1 : 0);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive char column into an Arrow
     * unsigned 16-bit vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code UInt2Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeCharField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        UInt2Vector target = (UInt2Vector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.CharVector source = df.charVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getChar(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable char column into an Arrow unsigned
     * 16-bit vector. A {@code null} element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code UInt2Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeCharObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        UInt2Vector vector = (UInt2Vector) fieldVector;
        smile.data.vector.Vector<Character> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Character x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive byte column into an Arrow
     * 8-bit integer vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code TinyIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeByteField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        TinyIntVector target = (TinyIntVector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.ByteVector source = df.byteVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getByte(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable byte column into an Arrow 8-bit integer
     * vector. A {@code null} element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code TinyIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeByteObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        TinyIntVector vector = (TinyIntVector) fieldVector;
        smile.data.vector.Vector<Byte> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Byte x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive short column into an Arrow
     * 16-bit integer vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code SmallIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeShortField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        SmallIntVector target = (SmallIntVector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.ShortVector source = df.shortVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getShort(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable short column into an Arrow 16-bit integer
     * vector. A {@code null} element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code SmallIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeShortObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        SmallIntVector vector = (SmallIntVector) fieldVector;
        smile.data.vector.Vector<Short> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Short x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive long column into an Arrow
     * 64-bit integer vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code BigIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeLongField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        BigIntVector target = (BigIntVector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.LongVector source = df.longVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getLong(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable long column into an Arrow 64-bit integer
     * vector. A {@code null} element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code BigIntVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeLongObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        BigIntVector vector = (BigIntVector) fieldVector;
        smile.data.vector.Vector<Long> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Long x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive float column into an Arrow
     * single-precision vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code Float4Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeFloatField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        Float4Vector target = (Float4Vector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.FloatVector source = df.floatVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getFloat(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable float column into an Arrow
     * single-precision vector. A {@code null} element is stored as an
     * Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code Float4Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeFloatObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        Float4Vector vector = (Float4Vector) fieldVector;
        smile.data.vector.Vector<Float> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Float x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Copies {@code count} values of a primitive double column into an Arrow
     * double-precision vector.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code Float8Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeDoubleField(DataFrame df, FieldVector fieldVector, int from, int count) {
        // Size and allocate the Arrow buffers before any value is stored.
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        Float8Vector target = (Float8Vector) fieldVector;
        String name = fieldVector.getField().getName();
        smile.data.vector.DoubleVector source = df.doubleVector(name);
        for (int slot = 0; slot < count; slot++) {
            target.set(slot, source.getDouble(from + slot));
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a nullable double column into an Arrow
     * double-precision vector. A {@code null} element is stored as an
     * Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code Float8Vector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeDoubleObjectField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        Float8Vector vector = (Float8Vector) fieldVector;
        smile.data.vector.Vector<Double> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            // i is the slot in the Arrow vector, j is the row in the data frame.
            // The source must be read at j, otherwise every batch would
            // repeat the first rows of the column.
            Double x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a string column into an Arrow variable-width
     * character vector. Strings are encoded as UTF-8; a {@code null}
     * element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code VarCharVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     * @throws UnsupportedEncodingException kept in the signature for the
     *         existing callers; the charset constant used here cannot
     *         raise it.
     */
    private void writeStringField(DataFrame df, FieldVector fieldVector, int from, int count) throws UnsupportedEncodingException {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        VarCharVector vector = (VarCharVector) fieldVector;
        smile.data.vector.StringVector column = df.stringVector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            String x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                // Use the charset constant instead of a by-name lookup.
                vector.setSafe(i, x.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a decimal column into an Arrow decimal vector.
     * A {@code null} element is stored as an Arrow null.
     *
     * <p>Arrow only accepts a {@code BigDecimal} whose scale equals the scale
     * declared by the vector's type. Values with a smaller scale are therefore
     * widened (padded with trailing zeros, which is lossless) before they are
     * stored. Values with a larger scale are passed through unchanged and are
     * still subject to Arrow's own validation.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code DecimalVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     * @throws UnsupportedEncodingException never thrown by this body; kept in
     *         the signature for the existing callers.
     */
    private void writeDecimalField(DataFrame df, FieldVector fieldVector, int from, int count) throws UnsupportedEncodingException {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        DecimalVector vector = (DecimalVector) fieldVector;
        int scale = ((ArrowType.Decimal) fieldVector.getField().getType()).getScale();
        smile.data.vector.Vector<BigDecimal> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            BigDecimal x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                if (x.scale() < scale) {
                    // Increasing the scale never requires rounding.
                    x = x.setScale(scale);
                }
                vector.setIndexDefined(i);
                vector.setSafe(i, x);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a date column into an Arrow day-resolution date
     * vector. Each date is stored as its epoch day; a {@code null} element
     * is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code DateDayVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     * @throws ArithmeticException if a date's epoch day does not fit in the
     *         32-bit slot of the vector.
     */
    private void writeDateField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        DateDayVector vector = (DateDayVector) fieldVector;
        smile.data.vector.Vector<LocalDate> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            LocalDate x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                // Fail loudly rather than silently truncating an
                // out-of-range epoch day to a wrong date.
                vector.setSafe(i, Math.toIntExact(x.toEpochDay()));
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a time column into an Arrow time vector.
     * A {@code null} element is stored as an Arrow null.
     *
     * <p>Two vector layouts are supported. A {@code TimeMilliVector} receives
     * the millisecond of the day (sub-millisecond digits are dropped); this is
     * the layout that matches the millisecond/32-bit time type declared by
     * {@code toArrowField} in this file. Any other vector is treated as a
     * {@code TimeNanoVector} and receives the nanosecond of the day.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     * @throws ClassCastException if the vector is neither a
     *         {@code TimeMilliVector} nor a {@code TimeNanoVector}.
     */
    private void writeTimeField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        // Exactly one of the two references is non-null.
        TimeMilliVector millis = fieldVector instanceof TimeMilliVector ? (TimeMilliVector) fieldVector : null;
        TimeNanoVector nanos = millis == null ? (TimeNanoVector) fieldVector : null;

        smile.data.vector.Vector<LocalTime> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            LocalTime x = column.get(j);
            if (millis != null) {
                if (x == null) {
                    millis.setNull(i);
                } else {
                    millis.setIndexDefined(i);
                    // At most 86,399,999, so the narrowing cast is safe.
                    millis.setSafe(i, (int) (x.toNanoOfDay() / 1_000_000L));
                }
            } else {
                if (x == null) {
                    nanos.setNull(i);
                } else {
                    nanos.setIndexDefined(i);
                    nanos.setSafe(i, x.toNanoOfDay());
                }
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a datetime column into an Arrow millisecond
     * timestamp vector. Each local datetime is converted to epoch
     * milliseconds using the UTC offset that the system default zone has
     * at the time of the call; a {@code null} element is stored as an
     * Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to
     *                    {@code TimeStampMilliTZVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeDateTimeField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        TimeStampMilliTZVector vector = (TimeStampMilliTZVector) fieldVector;
        smile.data.vector.Vector<LocalDateTime> column = df.vector(fieldVector.getField().getName());
        // Resolve the offset once so that every row of the batch is
        // converted with the same offset.
        ZoneOffset offset = OffsetDateTime.now().getOffset();
        for (int i = 0, j = from; i < count; i++, j++) {
            LocalDateTime x = column.get(j);
            if (x == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, x.toInstant(offset).toEpochMilli());
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Writes a slice of a byte array column into an Arrow variable-width
     * binary vector. A {@code null} element is stored as an Arrow null.
     *
     * @param df the source data frame; the column is looked up by the
     *           name of the Arrow field.
     * @param fieldVector the destination vector, cast to {@code VarBinaryVector}.
     * @param from the data frame row that maps to slot 0 of the vector.
     * @param count the number of rows to copy.
     */
    private void writeByteArrayField(DataFrame df, FieldVector fieldVector, int from, int count) {
        fieldVector.setInitialCapacity(count);
        fieldVector.allocateNew();

        VarBinaryVector vector = (VarBinaryVector) fieldVector;
        // The element type is stated explicitly so that get() yields byte[].
        smile.data.vector.Vector<byte[]> column = df.vector(fieldVector.getField().getName());
        for (int i = 0, j = from; i < count; i++, j++) {
            byte[] bytes = column.get(j);
            if (bytes == null) {
                vector.setNull(i);
            } else {
                vector.setIndexDefined(i);
                vector.setSafe(i, bytes);
            }
        }

        fieldVector.setValueCount(count);
    }

    /**
     * Converts a smile schema to an arrow schema by converting every field,
     * in order, with {@code toArrowField}. The arrow schema is created
     * without metadata.
     *
     * @param schema the smile schema.
     * @return the arrow schema.
     */
    private Schema toArrowSchema(StructType schema) {
        List<Field> fields = Arrays.stream(schema.fields())
                .map(this::toArrowField)
                .collect(Collectors.toList());

        return new Schema(fields, null);
    }

    /**
     * Converts an arrow schema to a smile schema by converting every field,
     * in order, with {@code toSmileField}.
     *
     * @param schema the arrow schema.
     * @return the smile schema.
     */
    private StructType toSmileSchema(Schema schema) {
        List<StructField> fields = schema.getFields().stream()
                .map(this::toSmileField)
                .collect(Collectors.toList());

        return DataTypes.struct(fields);
    }

    /**
     * Converts a smile struct field to an arrow field.
     *
     * <ul>
     * <li>Primitive types map to non-nullable arrow fields.</li>
     * <li>Decimal, String, Date, Time, and DateTime map to nullable fields.</li>
     * <li>Object types wrapping a boxed primitive map to the nullable
     *     counterpart of the primitive mapping.</li>
     * <li>Arrays of int, long, double, float, boolean, and short map to
     *     non-nullable lists; byte arrays map to Binary and char arrays to
     *     Utf8.</li>
     * <li>Struct types are converted recursively.</li>
     * </ul>
     *
     * @param field the smile field.
     * @return the arrow field.
     * @throws UnsupportedOperationException if the type has no mapping.
     */
    private Field toArrowField(StructField field) {
        switch (field.type.id()) {
            case Integer:
                return arrowField(field.name, false, new ArrowType.Int(32, true));
            case Long:
                return arrowField(field.name, false, new ArrowType.Int(64, true));
            case Double:
                return arrowField(field.name, false, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
            case Float:
                return arrowField(field.name, false, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
            case Boolean:
                return arrowField(field.name, false, new ArrowType.Bool());
            case Byte:
                return arrowField(field.name, false, new ArrowType.Int(8, true));
            case Short:
                return arrowField(field.name, false, new ArrowType.Int(16, true));
            case Char:
                // A char is carried as an unsigned 16-bit integer.
                return arrowField(field.name, false, new ArrowType.Int(16, false));
            case Decimal:
                return arrowField(field.name, true, new ArrowType.Decimal(28, 10));
            case String:
                return arrowField(field.name, true, new ArrowType.Utf8());
            case Date:
                return arrowField(field.name, true, new ArrowType.Date(DateUnit.DAY));
            case Time:
                return arrowField(field.name, true, new ArrowType.Time(TimeUnit.MILLISECOND, 32));
            case DateTime:
                return arrowField(field.name, true, new ArrowType.Timestamp(TimeUnit.MILLISECOND, java.time.ZoneOffset.UTC.getId()));
            case Object: {
                Class<?> clazz = ((ObjectType) field.type).getObjectClass();
                if (clazz == Integer.class) {
                    return arrowField(field.name, true, new ArrowType.Int(32, true));
                } else if (clazz == Long.class) {
                    return arrowField(field.name, true, new ArrowType.Int(64, true));
                } else if (clazz == Double.class) {
                    return arrowField(field.name, true, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
                } else if (clazz == Float.class) {
                    return arrowField(field.name, true, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
                } else if (clazz == Boolean.class) {
                    return arrowField(field.name, true, new ArrowType.Bool());
                } else if (clazz == Byte.class) {
                    return arrowField(field.name, true, new ArrowType.Int(8, true));
                } else if (clazz == Short.class) {
                    return arrowField(field.name, true, new ArrowType.Int(16, true));
                } else if (clazz == Character.class) {
                    return arrowField(field.name, true, new ArrowType.Int(16, false));
                }
                // Any other object class is unsupported.
                break;
            }
            case Array: {
                DataType etype = ((ArrayType) field.type).getComponentType();
                switch (etype.id()) {
                    case Integer:
                        return arrowListField(field.name, new ArrowType.Int(32, true));
                    case Long:
                        return arrowListField(field.name, new ArrowType.Int(64, true));
                    case Double:
                        return arrowListField(field.name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
                    case Float:
                        return arrowListField(field.name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
                    case Boolean:
                        return arrowListField(field.name, new ArrowType.Bool());
                    case Byte:
                        // byte[] is carried as a binary blob instead of a list.
                        return arrowField(field.name, true, new ArrowType.Binary());
                    case Short:
                        return arrowListField(field.name, new ArrowType.Int(16, true));
                    case Char:
                        // char[] is carried as a UTF-8 string instead of a list.
                        return arrowField(field.name, true, new ArrowType.Utf8());
                }
                // Any other element type is unsupported.
                break;
            }
            case Struct: {
                StructType children = (StructType) field.type;
                return new Field(field.name,
                        new FieldType(false, new ArrowType.Struct(), null),
                        // children type
                        Arrays.stream(children.fields()).map(this::toArrowField).collect(Collectors.toList())
                );
            }
        }

        throw new UnsupportedOperationException("Unsupported smile to arrow type conversion: " + field.type);
    }

    /** Creates a childless arrow field of the given type and nullability. */
    private static Field arrowField(String name, boolean nullable, ArrowType type) {
        FieldType fieldType = nullable ? FieldType.nullable(type) : new FieldType(false, type, null);
        return new Field(name, fieldType, null);
    }

    /** Creates a non-nullable arrow list field with one unnamed, non-nullable element field. */
    private static Field arrowListField(String name, ArrowType elementType) {
        Field element = new Field(null, new FieldType(false, elementType, null), null);
        return new Field(name, new FieldType(false, new ArrowType.List(), null), Arrays.asList(element));
    }

    /**
     * Converts an arrow field to a smile struct field.
     *
     * <p>For Int, FloatingPoint, and Bool fields the nullability of the arrow
     * field selects between the primitive smile type and its object
     * counterpart. Unsigned 16-bit integers map to char types. List types
     * must have exactly one child field, whose converted type becomes the
     * array component type. Struct types are converted recursively.
     *
     * @param field the arrow field.
     * @return the smile field with the same name.
     * @throws UnsupportedOperationException if the arrow type, integer bit
     *         width, or floating point precision has no mapping.
     * @throws IllegalStateException if a list field does not have exactly
     *         one child field.
     */
    private StructField toSmileField(Field field) {
        String name = field.getName();
        ArrowType type = field.getType();
        boolean nullable = field.isNullable();
        switch (type.getTypeID()) {
            case Int: {
                ArrowType.Int itype = (ArrowType.Int) type;
                int bitWidth = itype.getBitWidth();
                switch (bitWidth) {
                    case  8: return new StructField(name, nullable ? DataTypes.ByteObjectType : DataTypes.ByteType);
                    case 16:
                        if (itype.getIsSigned())
                            return new StructField(name, nullable ? DataTypes.ShortObjectType : DataTypes.ShortType);
                        else
                            return new StructField(name, nullable ? DataTypes.CharObjectType : DataTypes.CharType);
                    case 32: return new StructField(name, nullable ? DataTypes.IntegerObjectType : DataTypes.IntegerType);
                    case 64: return new StructField(name, nullable ? DataTypes.LongObjectType : DataTypes.LongType);
                    default: throw new UnsupportedOperationException("Unsupported integer bit width: " + bitWidth);
                }
            }

            case FloatingPoint: {
                FloatingPointPrecision precision = ((ArrowType.FloatingPoint) type).getPrecision();
                switch (precision) {
                    case DOUBLE: return new StructField(name, nullable ? DataTypes.DoubleObjectType : DataTypes.DoubleType);
                    case SINGLE: return new StructField(name, nullable ? DataTypes.FloatObjectType : DataTypes.FloatType);
                    // HALF and any other precision: reject explicitly so that
                    // control can never fall through into the Bool case.
                    default: throw new UnsupportedOperationException("Unsupported float precision: " + precision);
                }
            }

            case Bool:
                return new StructField(name, nullable ? DataTypes.BooleanObjectType : DataTypes.BooleanType);

            case Decimal:
                return new StructField(name, DataTypes.DecimalType);

            case Utf8:
                return new StructField(name, DataTypes.StringType);

            case Date:
                return new StructField(name, DataTypes.DateType);

            case Time:
                return new StructField(name, DataTypes.TimeType);

            case Timestamp:
                return new StructField(name, DataTypes.DateTimeType);

            case Binary:
            case FixedSizeBinary:
                return new StructField(name, DataTypes.ByteArrayType);

            case List:
            case FixedSizeList: {
                List<Field> child = field.getChildren();
                if (child.size() != 1) {
                    throw new IllegalStateException(String.format("List type has %d child fields.", child.size()));
                }

                return new StructField(name, DataTypes.array(toSmileField(child.get(0)).type));
            }

            case Struct: {
                List<StructField> children = field.getChildren().stream().map(this::toSmileField).collect(Collectors.toList());
                return new StructField(name, DataTypes.struct(children));
            }

            default:
                throw new UnsupportedOperationException("Unsupported arrow to smile type conversion: " + type);
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy