org.apache.lens.ml.server.MLServiceResource Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lens.ml.server;
import static org.apache.commons.lang.StringUtils.isBlank;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;
import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.api.MLTestReport;
import org.apache.lens.ml.api.ModelMetadata;
import org.apache.lens.ml.api.TestReport;
import org.apache.lens.ml.impl.ModelLoader;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.glassfish.jersey.media.multipart.FormDataParam;
/**
* Machine Learning service.
*/
@Path("/ml")
@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
public class MLServiceResource {
/** The Constant LOG. */
public static final Log LOG = LogFactory.getLog(MLServiceResource.class);
/** The ml service. */
MLService mlService;
/** The service provider. */
ServiceProvider serviceProvider;
/** The service provider factory. */
ServiceProviderFactory serviceProviderFactory;
private static final HiveConf HIVE_CONF;
/**
* Message indicating if ML service is up
*/
public static final String ML_UP_MESSAGE = "ML service is up";
static {
HIVE_CONF = new HiveConf();
// Add default config so that we know the service provider implementation
HIVE_CONF.addResource("lensserver-default.xml");
HIVE_CONF.addResource("lens-site.xml");
}
/**
* Instantiates a new ML service resource.
*/
public MLServiceResource() {
serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
}
private ServiceProvider getServiceProvider() {
if (serviceProvider == null) {
serviceProvider = serviceProviderFactory.getServiceProvider();
}
return serviceProvider;
}
/**
* Gets the service provider factory.
*
* @param conf the conf
* @return the service provider factory
*/
private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
Class> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
try {
return (ServiceProviderFactory) spfClass.newInstance();
} catch (InstantiationException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
private MLService getMlService() {
if (mlService == null) {
mlService = (MLService) getServiceProvider().getService(MLService.NAME);
}
return mlService;
}
/**
* Indicates if ML resource is up
*
* @return
*/
@GET
public String mlResourceUp() {
return ML_UP_MESSAGE;
}
/**
* Get a list of algos available
*
* @return
*/
@GET
@Path("algos")
public StringList getAlgoNames() {
List algos = getMlService().getAlgorithms();
StringList result = new StringList(algos);
return result;
}
/**
* Gets the human readable param description of an algorithm
*
* @param algorithm the algorithm
* @return the param description
*/
@GET
@Path("algos/{algorithm}")
public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
Map paramDesc = getMlService().getAlgoParamDescription(algorithm);
if (paramDesc == null) {
throw new NotFoundException("Param description not found for " + algorithm);
}
List descriptions = new ArrayList();
for (String key : paramDesc.keySet()) {
descriptions.add(key + " : " + paramDesc.get(key));
}
return new StringList(descriptions);
}
/**
* Get model ID list for a given algorithm.
*
* @param algorithm algorithm name
* @return the models for algo
* @throws LensException the lens exception
*/
@GET
@Path("models/{algorithm}")
public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
List models = getMlService().getModels(algorithm);
if (models == null || models.isEmpty()) {
throw new NotFoundException("No models found for algorithm " + algorithm);
}
return new StringList(models);
}
/**
* Get metadata of the model given algorithm and model ID.
*
* @param algorithm algorithm name
* @param modelID model ID
* @return model metadata
* @throws LensException the lens exception
*/
@GET
@Path("models/{algorithm}/{modelID}")
public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
throws LensException {
MLModel model = getMlService().getModel(algorithm, modelID);
if (model == null) {
throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
}
ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join(
model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID),
model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ","));
return meta;
}
/**
* Delete a model given model ID and algorithm name.
*
* @param algorithm the algorithm
* @param modelID the model id
* @return confirmation text
* @throws LensException the lens exception
*/
@DELETE
@Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
@Path("models/{algorithm}/{modelID}")
public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
throws LensException {
getMlService().deleteModel(algorithm, modelID);
return "DELETED model=" + modelID + " algorithm=" + algorithm;
}
/**
* Train a model given an algorithm name and algorithm parameters
*
* Following parameters are mandatory and must be passed as part of the form
*
*
* - table - input Hive table to load training data from
* - label - name of the labelled column
* - feature - one entry per feature column. At least one feature column is required
*
*
*
*
* @param algorithm algorithm name
* @param form form data
* @return if model is successfully trained, the model ID will be returned
* @throws LensException the lens exception
*/
@POST
@Consumes(MediaType.APPLICATION_FORM_URLENCODED)
@Path("{algorithm}/train")
public String train(@PathParam("algorithm") String algorithm, MultivaluedMap form)
throws LensException {
// Check if algo is valid
if (getMlService().getAlgoForName(algorithm) == null) {
throw new NotFoundException("Algo for algo: " + algorithm + " not found");
}
if (isBlank(form.getFirst("table"))) {
throw new BadRequestException("table parameter is rquired");
}
String table = form.getFirst("table");
if (isBlank(form.getFirst("label"))) {
throw new BadRequestException("label parameter is required");
}
// Check features
List featureNames = form.get("feature");
if (featureNames.size() < 1) {
throw new BadRequestException("At least one feature is required");
}
List algoArgs = new ArrayList();
Set>> paramSet = form.entrySet();
for (Map.Entry> e : paramSet) {
String p = e.getKey();
List values = e.getValue();
if ("algorithm".equals(p) || "table".equals(p)) {
continue;
} else if ("feature".equals(p)) {
for (String feature : values) {
algoArgs.add("feature");
algoArgs.add(feature);
}
} else if ("label".equals(p)) {
algoArgs.add("label");
algoArgs.add(values.get(0));
} else {
algoArgs.add(p);
algoArgs.add(values.get(0));
}
}
LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString());
String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
LOG.info("Done training " + table + " modelid = " + modelId);
return modelId;
}
/**
* Clear model cache (for admin use).
*
* @return OK if the cache was cleared
*/
@DELETE
@Path("clearModelCache")
@Produces(MediaType.TEXT_PLAIN)
public Response clearModelCache() {
ModelLoader.clearCache();
LOG.info("Cleared model cache");
return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
}
/**
* Run a test on a model for an algorithm.
*
* @param algorithm algorithm name
* @param modelID model ID
* @param table Hive table to run test on
* @param session Lens session ID. This session ID will be used to run the test query
* @return Test report ID
* @throws LensException the lens exception
*/
@POST
@Path("test/{table}/{algorithm}/{modelID}")
@Consumes(MediaType.MULTIPART_FORM_DATA)
public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
@PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
@FormDataParam("outputTable") String outputTable) throws LensException {
MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable);
return testReport.getReportID();
}
/**
* Get list of reports for a given algorithm.
*
* @param algoritm the algoritm
* @return the reports for algorithm
* @throws LensException the lens exception
*/
@GET
@Path("reports/{algorithm}")
public StringList getReportsForAlgorithm(@PathParam("algorithm") String algoritm) throws LensException {
List reports = getMlService().getTestReports(algoritm);
if (reports == null || reports.isEmpty()) {
throw new NotFoundException("No test reports found for " + algoritm);
}
return new StringList(reports);
}
/**
* Get a single test report given the algorithm name and report id.
*
* @param algorithm the algorithm
* @param reportID the report id
* @return the test report
* @throws LensException the lens exception
*/
@GET
@Path("reports/{algorithm}/{reportID}")
public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
throws LensException {
MLTestReport report = getMlService().getTestReport(algorithm, reportID);
if (report == null) {
throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
}
TestReport result = new TestReport(report.getTestTable(), report.getOutputTable(), report.getOutputColumn(),
report.getLabelColumn(), StringUtils.join(report.getFeatureColumns(), ","), report.getAlgorithm(),
report.getModelID(), report.getReportID(), report.getLensQueryID());
return result;
}
/**
* DELETE a report given the algorithm name and report ID.
*
* @param algorithm the algorithm
* @param reportID the report id
* @return the string
* @throws LensException the lens exception
*/
@DELETE
@Path("reports/{algorithm}/{reportID}")
@Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
throws LensException {
getMlService().deleteTestReport(algorithm, reportID);
return "DELETED report=" + reportID + " algorithm=" + algorithm;
}
/**
* Predict.
*
* @param algorithm the algorithm
* @param modelID the model id
* @param uriInfo the uri info
* @return the string
* @throws LensException the lens exception
*/
@GET
@Path("/predict/{algorithm}/{modelID}")
@Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON})
public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
@Context UriInfo uriInfo) throws LensException {
// Load the model instance
MLModel> model = getMlService().getModel(algorithm, modelID);
// Get input feature names
MultivaluedMap params = uriInfo.getQueryParameters();
String[] features = new String[model.getFeatureColumns().size()];
// Assuming that feature name parameters are same
int i = 0;
for (String feature : model.getFeatureColumns()) {
features[i++] = params.getFirst(feature);
}
// TODO needs a 'prediction formatter'
return getMlService().predict(algorithm, modelID, features).toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy