All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.nlp.transforms.MultiNlpTransform Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


package org.datavec.nlp.transforms;

import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Collections;
import java.util.List;

/**
 * A multi NLP transform takes in 1 or more bag of words transforms as a pipeline
 * and runs them in sequence.
 * This transform takes in a column name and 1 or more bag of words transforms to run.
 * Lastly, a new column name is specified.
 *
 * @author Adam Gibson
 */
public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {

    private BagOfWordsTransform[] transforms;
    private String newColumnName;
    private List vocabWords;

    /**
     *
     * @param columnName
     * @param transforms
     * @param newColumnName
     */
    @JsonCreator
    public MultiNlpTransform(@JsonProperty("columnName") String columnName,
                             @JsonProperty("transforms") BagOfWordsTransform[] transforms,
                             @JsonProperty("newColumnName") String newColumnName) {
        super(columnName);
        this.transforms = transforms;
        this.vocabWords = transforms[0].vocabWords();
        if(transforms.length > 1) {
            for(int i = 1; i < transforms.length; i++) {
                if(!transforms[i].vocabWords().equals(vocabWords)) {
                    throw new IllegalArgumentException("Vocab words not consistent across transforms!");
                }
            }
        }

        this.newColumnName = newColumnName;
    }

    @Override
    public Object mapSequence(Object sequence) {
        NDArrayList ndArrayList = new NDArrayList();
        for(BagOfWordsTransform bagofWordsTransform : transforms) {
            ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List>) sequence)));
        }

        return ndArrayList.array();
    }

    @Override
    public List> mapSequence(List> sequence) {
     return Collections.singletonList(Collections.singletonList(new NDArrayWritable(transformFrom(sequence))));
    }

    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(newName,outputShape());
    }

    @Override
    public Writable map(Writable columnWritable) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    @Override
    public String toString() {
        return newColumnName;
    }

    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    @Override
    public long[] outputShape() {
        long[] ret = new long[transforms[0].outputShape().length];
        int validatedRank = transforms[0].outputShape().length;
        for(int i = 1; i < transforms.length; i++) {
            if(transforms[i].outputShape().length != validatedRank) {
                throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank);
            }
        }
        for(int i = 0; i < transforms.length; i++) {
            for(int j = 0; j < validatedRank; j++)
            ret[j] += transforms[i].outputShape()[j];
        }

        return ret;
    }

    @Override
    public List vocabWords() {
        return vocabWords;
    }

    @Override
    public INDArray transformFromObject(List> tokens) {
        NDArrayList ndArrayList = new NDArrayList();
        for(BagOfWordsTransform bagofWordsTransform : transforms) {
            INDArray arr2 = bagofWordsTransform.transformFromObject(tokens);
            arr2 = arr2.reshape(arr2.length());
            NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
            ndArrayList.addAll(newList);        }

        return ndArrayList.array();
    }

    @Override
    public INDArray transformFrom(List> tokens) {
        NDArrayList ndArrayList = new NDArrayList();
        for(BagOfWordsTransform bagofWordsTransform : transforms) {
            INDArray arr2 = bagofWordsTransform.transformFrom(tokens);
            arr2 = arr2.reshape(arr2.length());
            NDArrayList newList = new NDArrayList(arr2,(int) arr2.length());
            ndArrayList.addAll(newList);
        }

        return ndArrayList.array();
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy