All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lens.client.LensMLClient Maven / Gradle / Ivy

There is a newer version: 2.7.1
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.client;

import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.ws.rs.core.Form;

import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.api.LensML;
import org.apache.lens.ml.api.MLTestReport;
import org.apache.lens.ml.api.ModelMetadata;
import org.apache.lens.ml.api.TestReport;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/**
 * Client side implementation of LensML
 */
public class LensMLClient implements LensML, Closeable {
  private static final Log LOG = LogFactory.getLog(LensMLClient.class);

  /** The client. */
  private LensMLJerseyClient client;

  public LensMLClient(String password) {
    this(new LensClientConfig(), password);
  }

  public LensMLClient(LensClientConfig conf, String password) {
    this(conf, conf.getUser(), password);
  }

  public LensMLClient(String username, String password) {
    this(new LensClientConfig(), username, password);
  }

  public LensMLClient(LensClientConfig conf, String username, String password) {
    this(new LensClient(conf, username, password));
  }

  public LensMLClient(LensClient lensClient) {
    client = new LensMLJerseyClient(lensClient.getConnection(), lensClient
        .getConnection().getSessionHandle());
  }

  /**
   * Get list of available machine learning algorithms
   *
   * @return
   */
  @Override
  public List getAlgorithms() {
    return client.getAlgoNames();
  }

  /**
   * Get user friendly information about parameters accepted by the algorithm.
   *
   * @param algorithm the algorithm
   * @return map of param key to its help message
   */
  @Override
  public Map getAlgoParamDescription(String algorithm) {
    List paramDesc = client.getParamDescriptionOfAlgo(algorithm);
    // convert paramDesc to map
    Map paramDescMap = new LinkedHashMap();
    for (String str : paramDesc) {
      String[] keyHelp = StringUtils.split(str, ":");
      paramDescMap.put(keyHelp[0].trim(), keyHelp[1].trim());
    }
    return paramDescMap;
  }

  /**
   * Get a algo object instance which could be used to generate a model of the given algorithm.
   *
   * @param algorithm the algorithm
   * @return the algo for name
   * @throws LensException the lens exception
   */
  @Override
  public MLAlgo getAlgoForName(String algorithm) throws LensException {
    throw new UnsupportedOperationException("MLAlgo cannot be accessed from client");
  }

  /**
   * Create a model using the given HCatalog table as input. The arguments should contain information needeed to
   * generate the model.
   *
   * @param table     the table
   * @param algorithm the algorithm
   * @param args      the args
   * @return Unique ID of the model created after training is complete
   * @throws LensException the lens exception
   */
  @Override
  public String train(String table, String algorithm, String[] args) throws LensException {
    Form trainParams = new Form();
    trainParams.param("table", table);
    for (int i = 0; i < args.length; i += 2) {
      trainParams.param(args[i], args[i + 1]);
    }
    return client.trainModel(algorithm, trainParams);
  }

  /**
   * Get model IDs for the given algorithm.
   *
   * @param algorithm the algorithm
   * @return the models
   * @throws LensException the lens exception
   */
  @Override
  public List getModels(String algorithm) throws LensException {
    return client.getModelsForAlgorithm(algorithm);
  }

  /**
   * Get a model instance given the algorithm name and model ID.
   *
   * @param algorithm the algorithm
   * @param modelId   the model id
   * @return the model
   * @throws LensException the lens exception
   */
  @Override
  public MLModel getModel(String algorithm, String modelId) throws LensException {
    ModelMetadata metadata = client.getModelMetadata(algorithm, modelId);
    String modelPathURI = metadata.getModelPath();

    ObjectInputStream in = null;
    try {
      URI modelURI = new URI(modelPathURI);
      Path modelPath = new Path(modelURI);
      FileSystem fs = FileSystem.get(modelURI, client.getConf());
      in = new ObjectInputStream(fs.open(modelPath));
      MLModel model = (MLModel) in.readObject();
      return model;
    } catch (IOException e) {
      throw new LensException(e);
    } catch (URISyntaxException e) {
      throw new LensException(e);
    } catch (ClassNotFoundException e) {
      throw new LensException(e);
    } finally {
      if (in != null) {
        try {
          in.close();
        } catch (IOException e) {
          e.printStackTrace();
        }
      }
    }

  }

  /**
   * Get the FS location where model instance is saved.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @return the model path
   */
  @Override
  public String getModelPath(String algorithm, String modelID) {
    ModelMetadata metadata = client.getModelMetadata(algorithm, modelID);
    return metadata.getModelPath();
  }

  /**
   * Evaluate model by running it against test data contained in the given table.
   *
   * @param session   the session
   * @param table     the table
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @return Test report object containing test output table, and various evaluation metrics
   * @throws LensException the lens exception
   */
  @Override
  public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
    String outputTable) throws LensException {
    String reportID = client.testModel(table, algorithm, modelID, outputTable);
    return getTestReport(algorithm, reportID);
  }

  /**
   * Get test reports for an algorithm.
   *
   * @param algorithm the algorithm
   * @return the test reports
   * @throws LensException the lens exception
   */
  @Override
  public List getTestReports(String algorithm) throws LensException {
    return client.getTestReportsOfAlgorithm(algorithm);
  }

  /**
   * Get a test report by ID.
   *
   * @param algorithm the algorithm
   * @param reportID  the report id
   * @return the test report
   * @throws LensException the lens exception
   */
  @Override
  public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
    TestReport report = client.getTestReport(algorithm, reportID);
    MLTestReport mlTestReport = new MLTestReport();
    mlTestReport.setAlgorithm(report.getAlgorithm());
    mlTestReport.setFeatureColumns(Arrays.asList(report.getFeatureColumns().split("\\,+")));
    mlTestReport.setLensQueryID(report.getQueryID());
    mlTestReport.setLabelColumn(report.getLabelColumn());
    mlTestReport.setModelID(report.getModelID());
    mlTestReport.setOutputColumn(report.getOutputColumn());
    mlTestReport.setPredictionResultColumn(report.getOutputColumn());
    mlTestReport.setQueryID(report.getQueryID());
    mlTestReport.setReportID(report.getReportID());
    mlTestReport.setTestTable(report.getTestTable());
    return mlTestReport;
  }

  /**
   * Online predict call given a model ID, algorithm name and sample feature values.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @param features  the features
   * @return prediction result
   * @throws LensException the lens exception
   */
  @Override
  public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
    return getModel(algorithm, modelID).predict(features);
  }

  /**
   * Permanently delete a model instance.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @throws LensException the lens exception
   */
  @Override
  public void deleteModel(String algorithm, String modelID) throws LensException {
    client.deleteModel(algorithm, modelID);
  }

  /**
   * Permanently delete a test report instance.
   *
   * @param algorithm the algorithm
   * @param reportID  the report id
   * @throws LensException the lens exception
   */
  @Override
  public void deleteTestReport(String algorithm, String reportID) throws LensException {
    client.deleteTestReport(algorithm, reportID);
  }

  /**
   * Close connection
   */
  @Override
  public void close() throws IOException {
    client.close();
  }

  public LensSessionHandle getSessionHandle() {
    return client.getSessionHandle();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy