All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.api.transform.schema.SequenceSchema Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*-
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.datavec.api.transform.schema;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.writable.*;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.List;

/**
 * A SequenceSchema is a {@link Schema} for sequential data.
 *
 * @author Alex Black
 */
@EqualsAndHashCode(callSuper = true)
@Data
public class SequenceSchema extends Schema {
    private final Integer minSequenceLength;
    private final Integer maxSequenceLength;

    public SequenceSchema(List columnMetaData) {
        this(columnMetaData, null, null);
    }

    public SequenceSchema(@JsonProperty("columns") List columnMetaData,
                    @JsonProperty("minSequenceLength") Integer minSequenceLength,
                    @JsonProperty("maxSequenceLength") Integer maxSequenceLength) {
        super(columnMetaData);
        this.minSequenceLength = minSequenceLength;
        this.maxSequenceLength = maxSequenceLength;
    }

    private SequenceSchema(Builder builder) {
        super(builder);
        this.minSequenceLength = builder.minSequenceLength;
        this.maxSequenceLength = builder.maxSequenceLength;
    }

    @Override
    public SequenceSchema newSchema(List columnMetaData) {
        return new SequenceSchema(columnMetaData, minSequenceLength, maxSequenceLength);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        int nCol = numColumns();

        int maxNameLength = 0;
        for (String s : getColumnNames()) {
            maxNameLength = Math.max(maxNameLength, s.length());
        }

        //Header:
        sb.append("SequenceSchema(");

        if (minSequenceLength != null)
            sb.append("minSequenceLength=").append(minSequenceLength);
        if (maxSequenceLength != null) {
            if (minSequenceLength != null)
                sb.append(",");
            sb.append("maxSequenceLength=").append(maxSequenceLength);
        }

        sb.append(")\n");
        sb.append(String.format("%-6s", "idx")).append(String.format("%-" + (maxNameLength + 8) + "s", "name"))
                        .append(String.format("%-15s", "type")).append("meta data").append("\n");

        for (int i = 0; i < nCol; i++) {
            String colName = getName(i);
            ColumnType type = getType(i);
            ColumnMetaData meta = getMetaData(i);
            String paddedName = String.format("%-" + (maxNameLength + 8) + "s", "\"" + colName + "\"");
            sb.append(String.format("%-6d", i)).append(paddedName).append(String.format("%-15s", type)).append(meta)
                            .append("\n");
        }

        return sb.toString();
    }

    public static class Builder extends Schema.Builder {

        private Integer minSequenceLength;
        private Integer maxSequenceLength;

        public Builder minSequenceLength(int minSequenceLength) {
            this.minSequenceLength = minSequenceLength;
            return this;
        }

        public Builder maxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }


        @Override
        public SequenceSchema build() {
            return new SequenceSchema(this);
        }


    }


    /**
     * Infers a sequence schema based
     * on the record
     * @param record the record to infer the schema based on
     * @return the inferred sequence schema
     *
     */
    public static SequenceSchema inferSequenceMulti(List>> record) {
        SequenceSchema.Builder builder = new SequenceSchema.Builder();
        int minSequenceLength = record.get(0).size();
        int maxSequenceLength = record.get(0).size();
        for (int i = 0; i < record.size(); i++) {
            if (record.get(i) instanceof DoubleWritable)
                builder.addColumnDouble(String.valueOf(i));
            else if (record.get(i) instanceof IntWritable)
                builder.addColumnInteger(String.valueOf(i));
            else if (record.get(i) instanceof LongWritable)
                builder.addColumnLong(String.valueOf(i));
            else if (record.get(i) instanceof FloatWritable)
                builder.addColumnFloat(String.valueOf(i));

            else
                throw new IllegalStateException("Illegal writable for infering schema of type "
                                + record.get(i).getClass().toString() + " with record " + record.get(0));
            builder.minSequenceLength(Math.min(record.get(i).size(), minSequenceLength));
            builder.maxSequenceLength(Math.max(record.get(i).size(), maxSequenceLength));
        }


        return builder.build();
    }

    /**
     * Infers a sequence schema based
     * on the record
     * @param record the record to infer the schema based on
     * @return the inferred sequence schema
     *
     */
    public static SequenceSchema inferSequence(List> record) {
        SequenceSchema.Builder builder = new SequenceSchema.Builder();
        for (int i = 0; i < record.size(); i++) {
            if (record.get(i) instanceof DoubleWritable)
                builder.addColumnDouble(String.valueOf(i));
            else if (record.get(i) instanceof IntWritable)
                builder.addColumnInteger(String.valueOf(i));
            else if (record.get(i) instanceof LongWritable)
                builder.addColumnLong(String.valueOf(i));
            else if (record.get(i) instanceof FloatWritable)
                builder.addColumnFloat(String.valueOf(i));

            else
                throw new IllegalStateException(
                                "Illegal writable for infering schema of type " + record.get(i).getClass().toString());
        }

        builder.minSequenceLength(record.size());
        builder.maxSequenceLength(record.size());
        return builder.build();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy