All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.lens.ml.server.MLServiceResource Maven / Gradle / Ivy

There is a newer version: 2.7.1
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml.server;

import static org.apache.commons.lang.StringUtils.isBlank;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;

import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.api.MLTestReport;
import org.apache.lens.ml.api.ModelMetadata;
import org.apache.lens.ml.api.TestReport;
import org.apache.lens.ml.impl.ModelLoader;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;

import org.glassfish.jersey.media.multipart.FormDataParam;

/**
 * Machine Learning service.
 */
@Path("/ml")
@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
public class MLServiceResource {

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(MLServiceResource.class);

  /** The ml service. */
  MLService mlService;

  /** The service provider. */
  ServiceProvider serviceProvider;

  /** The service provider factory. */
  ServiceProviderFactory serviceProviderFactory;

  private static final HiveConf HIVE_CONF;

  /**
   * Message indicating if ML service is up
   */
  public static final String ML_UP_MESSAGE = "ML service is up";

  static {
    HIVE_CONF = new HiveConf();
    // Add default config so that we know the service provider implementation
    HIVE_CONF.addResource("lensserver-default.xml");
    HIVE_CONF.addResource("lens-site.xml");
  }

  /**
   * Instantiates a new ML service resource.
   */
  public MLServiceResource() {
    serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
  }

  private ServiceProvider getServiceProvider() {
    if (serviceProvider == null) {
      serviceProvider = serviceProviderFactory.getServiceProvider();
    }
    return serviceProvider;
  }

  /**
   * Gets the service provider factory.
   *
   * @param conf the conf
   * @return the service provider factory
   */
  private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
    Class spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
    try {
      return (ServiceProviderFactory) spfClass.newInstance();
    } catch (InstantiationException e) {
      throw new RuntimeException(e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException(e);
    }
  }

  private MLService getMlService() {
    if (mlService == null) {
      mlService = (MLService) getServiceProvider().getService(MLService.NAME);
    }
    return mlService;
  }

  /**
   * Indicates if ML resource is up
   *
   * @return
   */
  @GET
  public String mlResourceUp() {
    return ML_UP_MESSAGE;
  }

  /**
   * Get a list of algos available
   *
   * @return
   */
  @GET
  @Path("algos")
  public StringList getAlgoNames() {
    List algos = getMlService().getAlgorithms();
    StringList result = new StringList(algos);
    return result;
  }

  /**
   * Gets the human readable param description of an algorithm
   *
   * @param algorithm the algorithm
   * @return the param description
   */
  @GET
  @Path("algos/{algorithm}")
  public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
    Map paramDesc = getMlService().getAlgoParamDescription(algorithm);
    if (paramDesc == null) {
      throw new NotFoundException("Param description not found for " + algorithm);
    }

    List descriptions = new ArrayList();
    for (String key : paramDesc.keySet()) {
      descriptions.add(key + " : " + paramDesc.get(key));
    }
    return new StringList(descriptions);
  }

  /**
   * Get model ID list for a given algorithm.
   *
   * @param algorithm algorithm name
   * @return the models for algo
   * @throws LensException the lens exception
   */
  @GET
  @Path("models/{algorithm}")
  public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
    List models = getMlService().getModels(algorithm);
    if (models == null || models.isEmpty()) {
      throw new NotFoundException("No models found for algorithm " + algorithm);
    }
    return new StringList(models);
  }

  /**
   * Get metadata of the model given algorithm and model ID.
   *
   * @param algorithm algorithm name
   * @param modelID   model ID
   * @return model metadata
   * @throws LensException the lens exception
   */
  @GET
  @Path("models/{algorithm}/{modelID}")
  public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
    throws LensException {
    MLModel model = getMlService().getModel(algorithm, modelID);
    if (model == null) {
      throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
    }

    ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join(
      model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID),
      model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ","));
    return meta;
  }

  /**
   * Delete a model given model ID and algorithm name.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @return confirmation text
   * @throws LensException the lens exception
   */
  @DELETE
  @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
  @Path("models/{algorithm}/{modelID}")
  public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
    throws LensException {
    getMlService().deleteModel(algorithm, modelID);
    return "DELETED model=" + modelID + " algorithm=" + algorithm;
  }

  /**
   * Train a model given an algorithm name and algorithm parameters
   * 

* Following parameters are mandatory and must be passed as part of the form *

*

    *
  1. table - input Hive table to load training data from
  2. *
  3. label - name of the labelled column
  4. *
  5. feature - one entry per feature column. At least one feature column is required
  6. *
*

*

* * @param algorithm algorithm name * @param form form data * @return if model is successfully trained, the model ID will be returned * @throws LensException the lens exception */ @POST @Consumes(MediaType.APPLICATION_FORM_URLENCODED) @Path("{algorithm}/train") public String train(@PathParam("algorithm") String algorithm, MultivaluedMap form) throws LensException { // Check if algo is valid if (getMlService().getAlgoForName(algorithm) == null) { throw new NotFoundException("Algo for algo: " + algorithm + " not found"); } if (isBlank(form.getFirst("table"))) { throw new BadRequestException("table parameter is rquired"); } String table = form.getFirst("table"); if (isBlank(form.getFirst("label"))) { throw new BadRequestException("label parameter is required"); } // Check features List featureNames = form.get("feature"); if (featureNames.size() < 1) { throw new BadRequestException("At least one feature is required"); } List algoArgs = new ArrayList(); Set>> paramSet = form.entrySet(); for (Map.Entry> e : paramSet) { String p = e.getKey(); List values = e.getValue(); if ("algorithm".equals(p) || "table".equals(p)) { continue; } else if ("feature".equals(p)) { for (String feature : values) { algoArgs.add("feature"); algoArgs.add(feature); } } else if ("label".equals(p)) { algoArgs.add("label"); algoArgs.add(values.get(0)); } else { algoArgs.add(p); algoArgs.add(values.get(0)); } } LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString()); String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{})); LOG.info("Done training " + table + " modelid = " + modelId); return modelId; } /** * Clear model cache (for admin use). 
* * @return OK if the cache was cleared */ @DELETE @Path("clearModelCache") @Produces(MediaType.TEXT_PLAIN) public Response clearModelCache() { ModelLoader.clearCache(); LOG.info("Cleared model cache"); return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build(); } /** * Run a test on a model for an algorithm. * * @param algorithm algorithm name * @param modelID model ID * @param table Hive table to run test on * @param session Lens session ID. This session ID will be used to run the test query * @return Test report ID * @throws LensException the lens exception */ @POST @Path("test/{table}/{algorithm}/{modelID}") @Consumes(MediaType.MULTIPART_FORM_DATA) public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID, @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session, @FormDataParam("outputTable") String outputTable) throws LensException { MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable); return testReport.getReportID(); } /** * Get list of reports for a given algorithm. * * @param algoritm the algoritm * @return the reports for algorithm * @throws LensException the lens exception */ @GET @Path("reports/{algorithm}") public StringList getReportsForAlgorithm(@PathParam("algorithm") String algoritm) throws LensException { List reports = getMlService().getTestReports(algoritm); if (reports == null || reports.isEmpty()) { throw new NotFoundException("No test reports found for " + algoritm); } return new StringList(reports); } /** * Get a single test report given the algorithm name and report id. 
* * @param algorithm the algorithm * @param reportID the report id * @return the test report * @throws LensException the lens exception */ @GET @Path("reports/{algorithm}/{reportID}") public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID) throws LensException { MLTestReport report = getMlService().getTestReport(algorithm, reportID); if (report == null) { throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm); } TestReport result = new TestReport(report.getTestTable(), report.getOutputTable(), report.getOutputColumn(), report.getLabelColumn(), StringUtils.join(report.getFeatureColumns(), ","), report.getAlgorithm(), report.getModelID(), report.getReportID(), report.getLensQueryID()); return result; } /** * DELETE a report given the algorithm name and report ID. * * @param algorithm the algorithm * @param reportID the report id * @return the string * @throws LensException the lens exception */ @DELETE @Path("reports/{algorithm}/{reportID}") @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN}) public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID) throws LensException { getMlService().deleteTestReport(algorithm, reportID); return "DELETED report=" + reportID + " algorithm=" + algorithm; } /** * Predict. 
* * @param algorithm the algorithm * @param modelID the model id * @param uriInfo the uri info * @return the string * @throws LensException the lens exception */ @GET @Path("/predict/{algorithm}/{modelID}") @Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON}) public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID, @Context UriInfo uriInfo) throws LensException { // Load the model instance MLModel model = getMlService().getModel(algorithm, modelID); // Get input feature names MultivaluedMap params = uriInfo.getQueryParameters(); String[] features = new String[model.getFeatureColumns().size()]; // Assuming that feature name parameters are same int i = 0; for (String feature : model.getFeatureColumns()) { features[i++] = params.getFirst(feature); } // TODO needs a 'prediction formatter' return getMlService().predict(algorithm, modelID, features).toString(); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy