All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.lens.ml.impl.LensMLImpl Maven / Gradle / Ivy

There is a newer version: 2.7.1
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml.impl;

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;

import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.query.LensQuery;
import org.apache.lens.api.query.QueryHandle;
import org.apache.lens.api.query.QueryStatus;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLDriver;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.SparkMLDriver;
import org.apache.lens.ml.api.LensML;
import org.apache.lens.ml.api.MLTestReport;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.session.SessionService;

import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.spark.api.java.JavaSparkContext;

import org.glassfish.jersey.media.multipart.FormDataBodyPart;
import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;

/**
 * The Class LensMLImpl.
 */
public class LensMLImpl implements LensML {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(LensMLImpl.class);

  /** The drivers. */
  protected List drivers;

  /** The conf. */
  private HiveConf conf;

  /** The spark context. */
  private JavaSparkContext sparkContext;

  /** Check if the predict UDF has been registered for a user */
  private final Map predictUdfStatus;
  /** Background thread to periodically check if we need to clear expire status for a session */
  private ScheduledExecutorService udfStatusExpirySvc;

  /**
   * Instantiates a new lens ml impl.
   *
   * @param conf the conf
   */
  public LensMLImpl(HiveConf conf) {
    this.conf = conf;
    this.predictUdfStatus = new ConcurrentHashMap();
  }

  public HiveConf getConf() {
    return conf;
  }

  /**
   * Use an existing Spark context. Useful in case of
   *
   * @param jsc JavaSparkContext instance
   */
  public void setSparkContext(JavaSparkContext jsc) {
    this.sparkContext = jsc;
  }

  public List getAlgorithms() {
    List algos = new ArrayList();
    for (MLDriver driver : drivers) {
      algos.addAll(driver.getAlgoNames());
    }
    return algos;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
   */
  public MLAlgo getAlgoForName(String algorithm) throws LensException {
    for (MLDriver driver : drivers) {
      if (driver.isAlgoSupported(algorithm)) {
        return driver.getAlgoInstance(algorithm);
      }
    }
    throw new LensException("Algo not supported " + algorithm);
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
   */
  public String train(String table, String algorithm, String[] args) throws LensException {
    MLAlgo algo = getAlgoForName(algorithm);

    String modelId = UUID.randomUUID().toString();

    LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
      + Arrays.toString(args));

    String database = null;
    if (SessionState.get() != null) {
      database = SessionState.get().getCurrentDatabase();
    } else {
      database = "default";
    }

    MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);

    LOG.info("Done training model: " + modelId);

    model.setCreatedAt(new Date());
    model.setAlgoName(algorithm);

    Path modelLocation = null;
    try {
      modelLocation = persistModel(model);
      LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
      return model.getId();
    } catch (IOException e) {
      throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
    }
  }

  /**
   * Gets the algo dir.
   *
   * @param algoName the algo name
   * @return the algo dir
   * @throws IOException Signals that an I/O exception has occurred.
   */
  private Path getAlgoDir(String algoName) throws IOException {
    String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
    return new Path(new Path(modelSaveBaseDir), algoName);
  }

  /**
   * Persist model.
   *
   * @param model the model
   * @return the path
   * @throws IOException Signals that an I/O exception has occurred.
   */
  private Path persistModel(MLModel model) throws IOException {
    // Get model save path
    Path algoDir = getAlgoDir(model.getAlgoName());
    FileSystem fs = algoDir.getFileSystem(conf);

    if (!fs.exists(algoDir)) {
      fs.mkdirs(algoDir);
    }

    Path modelSavePath = new Path(algoDir, model.getId());
    ObjectOutputStream outputStream = null;

    try {
      outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
      outputStream.writeObject(model);
      outputStream.flush();
    } catch (IOException io) {
      LOG.error("Error saving model " + model.getId() + " reason: " + io.getMessage());
      throw io;
    } finally {
      IOUtils.closeQuietly(outputStream);
    }
    return modelSavePath;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
   */
  public List getModels(String algorithm) throws LensException {
    try {
      Path algoDir = getAlgoDir(algorithm);
      FileSystem fs = algoDir.getFileSystem(conf);
      if (!fs.exists(algoDir)) {
        return null;
      }

      List models = new ArrayList();

      for (FileStatus stat : fs.listStatus(algoDir)) {
        models.add(stat.getPath().getName());
      }

      if (models.isEmpty()) {
        return null;
      }

      return models;
    } catch (IOException ioex) {
      throw new LensException(ioex);
    }
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
   */
  public MLModel getModel(String algorithm, String modelId) throws LensException {
    try {
      return ModelLoader.loadModel(conf, algorithm, modelId);
    } catch (IOException e) {
      throw new LensException(e);
    }
  }

  /**
   * Inits the.
   *
   * @param hiveConf the hive conf
   */
  public synchronized void init(HiveConf hiveConf) {
    this.conf = hiveConf;

    // Get all the drivers
    String[] driverClasses = hiveConf.getStrings("lens.ml.drivers");

    if (driverClasses == null || driverClasses.length == 0) {
      throw new RuntimeException("No ML Drivers specified in conf");
    }

    LOG.info("Loading drivers " + Arrays.toString(driverClasses));
    drivers = new ArrayList(driverClasses.length);

    for (String driverClass : driverClasses) {
      Class cls;
      try {
        cls = Class.forName(driverClass);
      } catch (ClassNotFoundException e) {
        LOG.error("Driver class not found " + driverClass);
        continue;
      }

      if (!MLDriver.class.isAssignableFrom(cls)) {
        LOG.warn("Not a driver class " + driverClass);
        continue;
      }

      try {
        Class mlDriverClass = (Class) cls;
        MLDriver driver = mlDriverClass.newInstance();
        driver.init(toLensConf(conf));
        drivers.add(driver);
        LOG.info("Added driver " + driverClass);
      } catch (Exception e) {
        LOG.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
      }
    }
    if (drivers.isEmpty()) {
      throw new RuntimeException("No ML drivers loaded");
    }

    LOG.info("Inited ML service");
  }

  /**
   * Start.
   */
  public synchronized void start() {
    for (MLDriver driver : drivers) {
      try {
        if (driver instanceof SparkMLDriver && sparkContext != null) {
          ((SparkMLDriver) driver).useSparkContext(sparkContext);
        }
        driver.start();
      } catch (LensException e) {
        LOG.error("Failed to start driver " + driver, e);
      }
    }

    udfStatusExpirySvc = Executors.newSingleThreadScheduledExecutor();
    udfStatusExpirySvc.scheduleAtFixedRate(new UDFStatusExpiryRunnable(), 60, 60, TimeUnit.SECONDS);

    LOG.info("Started ML service");
  }

  /**
   * Stop.
   */
  public synchronized void stop() {
    for (MLDriver driver : drivers) {
      try {
        driver.stop();
      } catch (LensException e) {
        LOG.error("Failed to stop driver " + driver, e);
      }
    }
    drivers.clear();
    udfStatusExpirySvc.shutdownNow();
    LOG.info("Stopped ML service");
  }

  public synchronized HiveConf getHiveConf() {
    return conf;
  }

  /**
   * Clear models.
   */
  public void clearModels() {
    ModelLoader.clearCache();
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
   */
  public String getModelPath(String algorithm, String modelID) {
    return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
   * java.lang.String)
   */
  @Override
  public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
    String outputTable) throws LensException {
    return null;
  }

  /**
   * Test a model in embedded mode.
   *
   * @param sessionHandle the session handle
   * @param table         the table
   * @param algorithm     the algorithm
   * @param modelID       the model id
   * @param queryApiUrl   the query api url
   * @return the ML test report
   * @throws LensException the lens exception
   */
  public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
    String queryApiUrl, String outputTable) throws LensException {
    return testModel(sessionHandle, table, algorithm, modelID, new RemoteQueryRunner(sessionHandle, queryApiUrl),
      outputTable);
  }

  /**
   * Evaluate a model. Evaluation is done on data selected table from an input table. The model is run as a UDF and its
   * output is inserted into a table with a partition. Each evaluation is given a unique ID. The partition label is
   * associated with this unique ID.
   * 

*

* This call also required a query runner. Query runner is responsible for executing the evaluation query against Lens * server. *

* * @param sessionHandle the session handle * @param table the table * @param algorithm the algorithm * @param modelID the model id * @param queryRunner the query runner * @param outputTable table where test output will be written * @return the ML test report * @throws LensException the lens exception */ public MLTestReport testModel(final LensSessionHandle sessionHandle, String table, String algorithm, String modelID, QueryRunner queryRunner, String outputTable) throws LensException { if (sessionHandle == null) { throw new NullPointerException("Null session not allowed"); } // check if algorithm exists if (!getAlgorithms().contains(algorithm)) { throw new LensException("No such algorithm " + algorithm); } MLModel model; try { model = ModelLoader.loadModel(conf, algorithm, modelID); } catch (IOException e) { throw new LensException(e); } if (model == null) { throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm); } String database = null; if (SessionState.get() != null) { database = SessionState.get().getCurrentDatabase(); } String testID = UUID.randomUUID().toString().replace("-", "_"); final String testTable = outputTable; final String testResultColumn = "prediction_result"; // TODO support error metric UDAFs TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf) .database(database == null ? "default" : database).inputTable(table).featureColumns(model.getFeatureColumns()) .outputColumn(testResultColumn).lableColumn(model.getLabelColumn()).algorithm(algorithm).modelID(modelID) .outputTable(testTable).testID(testID).build(); String testQuery = spec.getTestQuery(); if (testQuery == null) { throw new LensException("Invalid test spec. 
" + "table=" + table + " algorithm=" + algorithm + " modelID=" + modelID); } if (!spec.isOutputTableExists()) { LOG.info("Output table '" + testTable + "' does not exist for test algorithm = " + algorithm + " modelid=" + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery()); // create the output table String createOutputTableQuery = spec.getCreateOutputTableQuery(); queryRunner.runQuery(createOutputTableQuery); LOG.info("Table created " + testTable); } // Check if ML UDF is registered in this session registerPredictUdf(sessionHandle, queryRunner); LOG.info("Running evaluation query " + testQuery); queryRunner.setQueryName("model_test_" + modelID); QueryHandle testQueryHandle = queryRunner.runQuery(testQuery); MLTestReport testReport = new MLTestReport(); testReport.setReportID(testID); testReport.setAlgorithm(algorithm); testReport.setFeatureColumns(model.getFeatureColumns()); testReport.setLabelColumn(model.getLabelColumn()); testReport.setModelID(model.getId()); testReport.setOutputColumn(testResultColumn); testReport.setOutputTable(testTable); testReport.setTestTable(table); testReport.setQueryID(testQueryHandle.toString()); // Save test report persistTestReport(testReport); LOG.info("Saved test report " + testReport.getReportID()); return testReport; } /** * Persist test report. 
* * @param testReport the test report * @throws LensException the lens exception */ private void persistTestReport(MLTestReport testReport) throws LensException { LOG.info("saving test report " + testReport.getReportID()); try { ModelLoader.saveTestReport(conf, testReport); LOG.info("Saved report " + testReport.getReportID()); } catch (IOException e) { LOG.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage()); } } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String) */ public List getTestReports(String algorithm) throws LensException { Path reportBaseDir = new Path(conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT)); FileSystem fs = null; try { fs = reportBaseDir.getFileSystem(conf); if (!fs.exists(reportBaseDir)) { return null; } Path algoDir = new Path(reportBaseDir, algorithm); if (!fs.exists(algoDir)) { return null; } List reports = new ArrayList(); for (FileStatus stat : fs.listStatus(algoDir)) { reports.add(stat.getPath().getName()); } return reports; } catch (IOException e) { LOG.error("Error reading report list for " + algorithm, e); return null; } } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String) */ public MLTestReport getTestReport(String algorithm, String reportID) throws LensException { try { return ModelLoader.loadReport(conf, algorithm, reportID); } catch (IOException e) { throw new LensException(e); } } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[]) */ public Object predict(String algorithm, String modelID, Object[] features) throws LensException { // Load the model instance MLModel model = getModel(algorithm, modelID); return model.predict(features); } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String) */ public void deleteModel(String algorithm, String modelID) throws 
LensException { try { ModelLoader.deleteModel(conf, algorithm, modelID); LOG.info("DELETED model " + modelID + " algorithm=" + algorithm); } catch (IOException e) { LOG.error( "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e); throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e); } } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String) */ public void deleteTestReport(String algorithm, String reportID) throws LensException { try { ModelLoader.deleteTestReport(conf, algorithm, reportID); LOG.info("DELETED report=" + reportID + " algorithm=" + algorithm); } catch (IOException e) { LOG.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e); throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e); } } /* * (non-Javadoc) * * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String) */ public Map getAlgoParamDescription(String algorithm) { MLAlgo algo = null; try { algo = getAlgoForName(algorithm); } catch (LensException e) { LOG.error("Error getting algo description : " + algorithm, e); return null; } if (algo instanceof BaseSparkAlgo) { return ((BaseSparkAlgo) algo).getArgUsage(); } return null; } /** * Submit model test query to a remote Lens server. */ class RemoteQueryRunner extends QueryRunner { /** The query api url. */ final String queryApiUrl; /** * Instantiates a new remote query runner. 
* * @param sessionHandle the session handle * @param queryApiUrl the query api url */ public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) { super(sessionHandle); this.queryApiUrl = queryApiUrl; } /* * (non-Javadoc) * * @see org.apache.lens.ml.TestQueryRunner#runQuery(java.lang.String) */ @Override public QueryHandle runQuery(String query) throws LensException { // Create jersey client for query endpoint Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build(); WebTarget target = client.target(queryApiUrl); final FormDataMultiPart mp = new FormDataMultiPart(); mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle, MediaType.APPLICATION_XML_TYPE)); mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query)); mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute")); LensConf lensConf = new LensConf(); lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + ""); lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + ""); mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf, MediaType.APPLICATION_XML_TYPE)); final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE), QueryHandle.class); LensQuery ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request() .get(LensQuery.class); QueryStatus stat = ctx.getStatus(); while (!stat.finished()) { ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request().get(LensQuery.class); stat = ctx.getStatus(); try { Thread.sleep(500); } catch (InterruptedException e) { throw new LensException(e); } } if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) { throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:" + stat.getErrorMessage()); } 
return ctx.getQueryHandle(); } } /** * To lens conf. * * @param conf the conf * @return the lens conf */ private LensConf toLensConf(HiveConf conf) { LensConf lensConf = new LensConf(); lensConf.getProperties().putAll(conf.getValByRegex(".*")); return lensConf; } protected void registerPredictUdf(LensSessionHandle sessionHandle, QueryRunner queryRunner) throws LensException { if (isUdfRegisterd(sessionHandle)) { // Already registered, nothing to do return; } LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString()); String regUdfQuery = "CREATE TEMPORARY FUNCTION " + HiveMLUDF.UDF_NAME + " AS '" + HiveMLUDF.class .getCanonicalName() + "'"; queryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString()); QueryHandle udfQuery = queryRunner.runQuery(regUdfQuery); LOG.info("udf query handle is " + udfQuery); predictUdfStatus.put(sessionHandle, true); LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString()); } protected boolean isUdfRegisterd(LensSessionHandle sessionHandle) { return predictUdfStatus.containsKey(sessionHandle); } /** * Periodically check if sessions have been closed, and clear UDF registered status. */ private class UDFStatusExpiryRunnable implements Runnable { public void run() { try { SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME); // Clear status of sessions which are closed. List sessions = new ArrayList(predictUdfStatus.keySet()); for (LensSessionHandle sessionHandle : sessions) { if (!sessionService.isOpen(sessionHandle)) { LOG.info("Session closed, removing UDF status: " + sessionHandle); predictUdfStatus.remove(sessionHandle); } } } catch (Exception exc) { LOG.warn("Error clearing UDF statuses", exc); } } } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy