All downloads are free. Search and download functionality uses the official Maven repository.

com.cloudera.oryx.app.speed.als.ALSSpeedModelManager Maven / Gradle / Ivy

Go to download

Speed and batch components of machine learning applications built from algorithms implemented in this project

There is a newer version: 2.8.0
Show newest version
/*
 * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.oryx.app.speed.als;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import com.cloudera.oryx.api.speed.AbstractSpeedModelManager;
import com.cloudera.oryx.app.als.ALSUtils;
import com.cloudera.oryx.app.common.fn.MLFunctions;
import com.cloudera.oryx.app.pmml.AppPMMLUtils;
import com.cloudera.oryx.common.lang.RateLimitCheck;
import com.cloudera.oryx.common.text.TextUtils;
import com.cloudera.oryx.common.math.SingularMatrixSolverException;
import com.cloudera.oryx.common.math.Solver;

/**
 * Implementation of {@link com.cloudera.oryx.api.speed.SpeedModelManager} that maintains and
 * updates an ALS model in memory.
 */
public final class ALSSpeedModelManager extends AbstractSpeedModelManager {

  private static final Logger log = LoggerFactory.getLogger(ALSSpeedModelManager.class);

  private ALSSpeedModel model;
  private final boolean noKnownItems;
  private final double minModelLoadFraction;
  private final RateLimitCheck logRateLimit;

  public ALSSpeedModelManager(Config config) {
    noKnownItems = config.getBoolean("oryx.als.no-known-items");
    minModelLoadFraction = config.getDouble("oryx.speed.min-model-load-fraction");
    Preconditions.checkArgument(minModelLoadFraction >= 0.0 && minModelLoadFraction <= 1.0);
    logRateLimit = new RateLimitCheck(1, TimeUnit.MINUTES);
  }

  @Override
  public void consumeKeyMessage(String key, String message, Configuration hadoopConf) throws IOException {
    switch (key) {
      case "UP":
        if (model == null) {
          return; // No model to interpret with yet, so skip it
        }
        // Note that here, the speed layer is actually listening for updates from
        // two sources. First is the batch layer. This is somewhat unusual, because
        // the batch layer usually only makes MODELs. The ALS model is too large
        // to send in one file, so is sent as a skeleton model plus a series of updates.
        // However it is also, neatly, listening for the same updates it produces below
        // in response to new data, and applying them to the in-memory representation.
        // ALS continues to be a somewhat special case here, in that it does benefit from
        // real-time updates to even the speed layer reference model.
        List update = TextUtils.readJSON(message, List.class);
        // Update
        String id = update.get(1).toString();
        float[] vector = TextUtils.convertViaJSON(update.get(2), float[].class);
        switch (update.get(0).toString()) {
          case "X":
            model.setUserVector(id, vector);
            break;
          case "Y":
            model.setItemVector(id, vector);
            break;
          default:
            throw new IllegalArgumentException("Bad message: " + message);
        }
        if (logRateLimit.test()) {
          log.info("{}", model);
        }
        break;

      case "MODEL":
      case "MODEL-REF":
        log.info("Loading new model");
        PMML pmml = AppPMMLUtils.readPMMLFromUpdateKeyMessage(key, message, hadoopConf);
        if (pmml == null) {
          return;
        }

        int features = Integer.parseInt(AppPMMLUtils.getExtensionValue(pmml, "features"));
        boolean implicit = Boolean.parseBoolean(AppPMMLUtils.getExtensionValue(pmml, "implicit"));
        boolean logStrength = Boolean.parseBoolean(AppPMMLUtils.getExtensionValue(pmml, "logStrength"));
        double epsilon = logStrength ?
            Double.parseDouble(AppPMMLUtils.getExtensionValue(pmml, "epsilon")) :
            Double.NaN;

        if (model == null || features != model.getFeatures()) {
          log.warn("No previous model, or # features has changed; creating new one");
          model = new ALSSpeedModel(features, implicit, logStrength, epsilon);
        }

        log.info("Updating model");
        // Remove users/items no longer in the model
        Collection XIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "XIDs"));
        Collection YIDs = new HashSet<>(AppPMMLUtils.getExtensionContent(pmml, "YIDs"));
        model.retainRecentAndUserIDs(XIDs);
        model.retainRecentAndItemIDs(YIDs);
        log.info("Model updated: {}", model);
        break;

      default:
        throw new IllegalArgumentException("Bad key: " + key);
    }
  }

  @Override
  public Iterable buildUpdates(JavaPairRDD newData) {
    if (model == null || model.getFractionLoaded() < minModelLoadFraction) {
      return Collections.emptyList();
    }

    // Trigger proactive computation of solvers for later use
    model.precomputeSolvers();

    // Order by timestamp and parse as tuples
    JavaRDD sortedValues =
        newData.values().sortBy(MLFunctions.TO_TIMESTAMP_FN, true, newData.partitions().size());
    JavaPairRDD,Double> tuples = sortedValues.mapToPair(line -> {
      try {
        String[] tokens = MLFunctions.PARSE_FN.call(line);
        String user = tokens[0];
        String item = tokens[1];
        Double strength = tokens[2].isEmpty() ? Double.NaN : Double.valueOf(tokens[2]);
        return new Tuple2<>(new Tuple2<>(user, item), strength);
      } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
        log.warn("Bad input: {}", line);
        throw e;
      }
    });

    JavaPairRDD,Double> aggregated;
    if (model.isImplicit()) {
      // See comments in ALSUpdate for explanation of how deletes are handled by this.
      aggregated = tuples.groupByKey().mapValues(MLFunctions.SUM_WITH_NAN);
    } else {
      // For non-implicit, last wins.
      aggregated = tuples.foldByKey(Double.NaN, (current, next) -> next);
    }

    JavaPairRDD,Double> noNaN =
        aggregated.filter(kv -> !Double.isNaN(kv._2()));

    JavaRDD inputRDD;
    if (model.isLogStrength()) {
      double epsilon = model.getEpsilon();
      inputRDD = noNaN.map(tuple -> new UserItemStrength(tuple._1()._1(), tuple._1()._2(),
                                                         (float) Math.log1p(tuple._2() / epsilon)));
    } else {
      inputRDD = noNaN.map(tuple -> new UserItemStrength(tuple._1()._1(), tuple._1()._2(),
                                                         tuple._2().floatValue()));
    }

    Collection input = inputRDD.collect();

    Solver XTXsolver;
    Solver YTYsolver;
    try {
      XTXsolver = model.getXTXSolver();
      YTYsolver = model.getYTYSolver();
    } catch (SingularMatrixSolverException smse) {
      log.info("Not enough data for solver yet ({}); skipping inputs", smse.getMessage());
      return Collections.emptyList();
    }
    if (XTXsolver == null || YTYsolver == null) {
      log.info("No solver available yet for model; skipping inputs");
      return Collections.emptyList();
    }

    return input.parallelStream().flatMap(uis -> {
      String user = uis.getUser();
      String item = uis.getItem();
      double value = uis.getStrength();

      // Xu is the current row u in the X user-feature matrix
      float[] Xu = model.getUserVector(user);
      // Yi is the current row i in the Y item-feature matrix
      float[] Yi = model.getItemVector(item);

      float[] newXu = ALSUtils.computeUpdatedXu(YTYsolver, value, Xu, Yi, model.isImplicit());
      // Similarly for Y vs X
      float[] newYi = ALSUtils.computeUpdatedXu(XTXsolver, value, Yi, Xu, model.isImplicit());

      Collection result = new ArrayList<>(2);
      if (newXu != null) {
        result.add(toUpdateJSON("X", user, newXu, item));
      }
      if (newYi != null) {
        result.add(toUpdateJSON("Y", item, newYi, user));
      }
      return result.stream();
    }).collect(Collectors.toList());
  }

  private String toUpdateJSON(String matrix, String ID, float[] vector, String otherID) {
    List args;
    if (noKnownItems) {
      args = Arrays.asList(matrix, ID, vector);
    } else {
      args = Arrays.asList(matrix, ID, vector, Collections.singletonList(otherID));
    }
    return TextUtils.joinJSON(args);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy