All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.gengoai.apollo.ml.model.topic.LSA Maven / Gradle / Ivy

There is a newer version: 2.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.gengoai.apollo.ml.model.topic;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.stream.spark.SparkStream;
import lombok.NonNull;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static com.gengoai.apollo.math.linalg.SparkLinearAlgebra.sparkSVD;
import static com.gengoai.apollo.math.linalg.SparkLinearAlgebra.toMatrix;
import static com.gengoai.apollo.ml.observation.VariableCollection.mergeVariableSpace;
import static com.gengoai.function.Functional.with;

/**
 * 

Distributed version of Latent Semantic * Analysis using Apache Spark. Documents are represented by examples and words are by features in the Example.

* * @author David B. Bracewell */ public class LSA extends BaseVectorTopicModel { private static final long serialVersionUID = 1L; private final Parameters parameters; private List topicVectors = new ArrayList<>(); /** * Instantiates a new LSA model with default parameters. */ public LSA() { this(new Parameters()); } /** * Instantiates a new LSA model with the given parameters. * * @param parameters the parameters */ public LSA(@NonNull Parameters parameters) { this.parameters = parameters; } /** * Instantiates a new LSA model with the given model updater. * * @param updater the updater */ public LSA(@NonNull Consumer updater) { this.parameters = with(new Parameters(), updater); } private Stream encode(Datum d) { if(parameters.combineInputs.value()) { return mergeVariableSpace(d.stream(getInputs())) .getVariableSpace() .map(o -> toCountVector(o, parameters.namingPattern.value())); } return d.stream(getInputs()) .map(o -> toCountVector(o, parameters.namingPattern.value())); } @Override public void estimate(@NonNull DataSet dataset) { encoderFit(dataset, getInputs(), parameters.namingPattern.value()); SparkStream stream = new SparkStream(dataset.parallelStream().toDistributedStream() .flatMap(this::encode) .map(o -> new DenseVector(o.toDoubleArray()))) .cache(); RowMatrix mat = new RowMatrix(stream.getRDD().rdd()); //since we have document x word, V is the word x component matrix // U = document x component, E = singular components, V = word x component // Transpose V to get component (topics) x words NDArray topicMatrix = toMatrix(sparkSVD(mat, parameters.K.value()).V().transpose()); for(int i = 0; i < parameters.K.value(); i++) { Counter featureDist = Counters.newCounter(); NDArray dist = NDArrayFactory.ND.columnVector(topicMatrix.getRow(i).toDoubleArray()); dist.forEachSparse((index, v) -> featureDist.set(encoder.decode(index), v)); topics.add(new Topic(i, featureDist)); topicVectors.add(dist); } } @Override public Parameters getFitParameters() { return parameters; } @Override public NDArray getTopicDistribution(String feature) { int i = encoder.encode(feature); if(i == -1) { return NDArrayFactory.ND.rowVector(new double[topics.size()]); } double[] dist = new double[topics.size()]; for(int i1 = 0; i1 < topics.size(); i1++) { dist[i1] = topicVectors.get(i1).get(i); } return NDArrayFactory.ND.rowVector(dist); } @Override protected NDArray inference(NDArray vector) { double[] scores = new double[topics.size()]; for(int i = 0; i < topics.size(); i++) { double score = vector.dot(topicVectors.get(i)); scores[i] = score; } return NDArrayFactory.ND.rowVector(scores); } /** * LSA Fit Parameters. */ public static class Parameters extends TopicModelFitParameters { /** * The number of topics to discover (default 100). */ public final Parameter K = parameter(Params.Clustering.K, 100); } }//END OF LSA




© 2015 - 2025 Weber Informatics LLC | Privacy Policy