All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lens.client.LensMLJerseyClient Maven / Gradle / Ivy

There is a newer version: 2.7.1
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.client;

import java.util.List;
import java.util.Map;

import javax.ws.rs.NotFoundException;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.MediaType;

import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
import org.apache.lens.ml.api.ModelMetadata;
import org.apache.lens.ml.api.TestReport;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;

import org.glassfish.jersey.media.multipart.FormDataBodyPart;
import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;

/**
 * Jersey (JAX-RS) client used to invoke the server-side Lens ML REST API.
 *
 * <p>All requests are issued against {@code <base connection url>/<ml resource path>}, where the resource path is
 * read from {@link #LENS_ML_RESOURCE_PATH} in the connection configuration (default
 * {@value #DEFAULT_ML_RESOURCE_PATH}).
 *
 * <p>One JAX-RS {@link Client} is built per instance and reused for every request; it is released by
 * {@link #close()}.
 */
public class LensMLJerseyClient {
  /** Configuration key holding the path of the ML resource relative to the base connection URL. */
  public static final String LENS_ML_RESOURCE_PATH = "lens.ml.resource.path";

  /** Path used when {@link #LENS_ML_RESOURCE_PATH} is not set. */
  public static final String DEFAULT_ML_RESOURCE_PATH = "ml";

  /** The Constant LOG. */
  public static final Log LOG = LogFactory.getLog(LensMLJerseyClient.class);

  /** The connection. */
  private final LensConnection connection;

  /** Explicit session to use; when {@code null}, the connection's own session handle is used instead. */
  private final LensSessionHandle sessionHandle;

  /**
   * JAX-RS client shared by all requests made through this instance. Clients are heavy-weight and must be closed,
   * so one is built up-front (with multipart support, needed by {@link #testModel}) and closed in {@link #close()}
   * instead of being created, and leaked, on every request.
   */
  private final Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();

  /**
   * Instantiates a new lens ml jersey client and opens the given connection.
   *
   * <p>No explicit session handle is stored, so requests that need one use the connection's session handle.
   *
   * @param connection the connection to open and use
   * @param password   the password passed to {@code connection.open}
   */
  public LensMLJerseyClient(LensConnection connection, String password) {
    this.connection = connection;
    connection.open(password);
    this.sessionHandle = null;
  }

  /**
   * Instantiates a new lens ml jersey client bound to an explicit session.
   *
   * <p>This constructor does not call {@code open} on the connection.
   *
   * @param connection    the connection
   * @param sessionHandle the session handle to send with requests that need one; if {@code null}, the connection's
   *                      session handle is used instead
   */
  public LensMLJerseyClient(LensConnection connection, LensSessionHandle sessionHandle) {
    this.connection = connection;
    this.sessionHandle = sessionHandle;
  }

  /**
   * Builds the base target for the ML resource.
   *
   * <p>The target is the connection's base URL followed by the path configured under
   * {@link #LENS_ML_RESOURCE_PATH} (default {@value #DEFAULT_ML_RESOURCE_PATH}).
   *
   * @return a target rooted at the ML resource
   */
  protected WebTarget getMLWebTarget() {
    LensConnectionParams connParams = connection.getLensConnectionParams();
    String baseURI = connParams.getBaseConnectionUrl();
    String mlURI = connParams.getConf().get(LENS_ML_RESOURCE_PATH, DEFAULT_ML_RESOURCE_PATH);
    return client.target(baseURI).path(mlURI);
  }

  /**
   * Gets the model metadata.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @return the model metadata, or {@code null} if the server responds with 404
   */
  public ModelMetadata getModelMetadata(String algorithm, String modelID) {
    try {
      return getMLWebTarget().path("models").path(algorithm).path(modelID).request().get(ModelMetadata.class);
    } catch (NotFoundException exc) {
      return null;
    }
  }

  /**
   * Delete model.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   */
  public void deleteModel(String algorithm, String modelID) {
    getMLWebTarget().path("models").path(algorithm).path(modelID).request().delete();
  }

  /**
   * Gets the models for algorithm.
   *
   * @param algorithm the algorithm
   * @return the model ids for the algorithm, or {@code null} if the server returns no body or responds with 404
   */
  public List<String> getModelsForAlgorithm(String algorithm) {
    try {
      StringList models = getMLWebTarget().path("models").path(algorithm).request().get(StringList.class);
      return models == null ? null : models.getElements();
    } catch (NotFoundException exc) {
      return null;
    }
  }

  /**
   * Gets the names of the algorithms known to the server.
   *
   * @return the algorithm names, or {@code null} if the server returns no body
   */
  public List<String> getAlgoNames() {
    StringList algoNames = getMLWebTarget().path("algos").request().get(StringList.class);
    return algoNames == null ? null : algoNames.getElements();
  }

  /**
   * Train model.
   *
   * @param algorithm the algorithm
   * @param params    the training parameters, sent as a URL-encoded form
   * @return the server's response body
   */
  public String trainModel(String algorithm, Form params) {
    return getMLWebTarget().path(algorithm).path("train").request(MediaType.APPLICATION_JSON_TYPE)
      .post(Entity.entity(params, MediaType.APPLICATION_FORM_URLENCODED_TYPE), String.class);
  }

  /**
   * Test model.
   *
   * <p>Posts a multipart form carrying the session handle (as XML) and the output table name.
   *
   * @param table       the table
   * @param algorithm   the algorithm
   * @param modelID     the model id
   * @param outputTable the output table name
   * @return the server's response body
   */
  public String testModel(String table, String algorithm, String modelID, String outputTable) {
    WebTarget modelTestTarget = getMLWebTarget().path("test").path(table).path(algorithm).path(modelID);

    FormDataMultiPart mp = new FormDataMultiPart();

    // Same selection rule as getSessionHandle(): explicit session first, otherwise the connection's session.
    LensSessionHandle session = getSessionHandle();

    mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), session,
      MediaType.APPLICATION_XML_TYPE));

    mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("outputTable").build(), outputTable));
    return modelTestTarget.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE), String.class);
  }

  /**
   * Gets the test reports of algorithm.
   *
   * @param algorithm the algorithm
   * @return the report ids for the algorithm, or {@code null} if the server returns no body or responds with 404
   */
  public List<String> getTestReportsOfAlgorithm(String algorithm) {
    try {
      StringList list = getMLWebTarget().path("reports").path(algorithm).request().get(StringList.class);
      return list == null ? null : list.getElements();
    } catch (NotFoundException exc) {
      return null;
    }
  }

  /**
   * Gets the test report.
   *
   * @param algorithm the algorithm
   * @param reportID  the report id
   * @return the test report, or {@code null} if the server responds with 404
   */
  public TestReport getTestReport(String algorithm, String reportID) {
    try {
      return getMLWebTarget().path("reports").path(algorithm).path(reportID).request().get(TestReport.class);
    } catch (NotFoundException exc) {
      return null;
    }
  }

  /**
   * Delete test report.
   *
   * @param algorithm the algorithm
   * @param reportID  the report id
   * @return the server's response body
   */
  public String deleteTestReport(String algorithm, String reportID) {
    return getMLWebTarget().path("reports").path(algorithm).path(reportID).request().delete(String.class);
  }

  /**
   * Predict single.
   *
   * <p>Each feature is sent as a query parameter named after its key.
   *
   * @param algorithm the algorithm
   * @param modelID   the model id
   * @param features  feature name to value; {@code null} is treated as no features
   * @return the server's response body
   */
  public String predictSingle(String algorithm, String modelID, Map<String, ?> features) {
    WebTarget target = getMLWebTarget().path("predict").path(algorithm).path(modelID);

    if (features != null) {
      for (Map.Entry<String, ?> entry : features.entrySet()) {
        // WebTarget is immutable: queryParam returns a NEW target, so the result must be kept.
        target = target.queryParam(entry.getKey(), entry.getValue());
      }
    }

    return target.request().get(String.class);
  }

  /**
   * Gets the param description of algo.
   *
   * @param algorithm the algorithm
   * @return the parameter descriptions, or {@code null} if the server returns no body or responds with 404
   */
  public List<String> getParamDescriptionOfAlgo(String algorithm) {
    try {
      StringList paramHelp = getMLWebTarget().path("algos").path(algorithm).request(MediaType.APPLICATION_XML)
        .get(StringList.class);
      return paramHelp == null ? null : paramHelp.getElements();
    } catch (NotFoundException exc) {
      return null;
    }
  }

  /**
   * Returns the configuration of the underlying connection.
   *
   * @return the connection configuration
   */
  public Configuration getConf() {
    return connection.getLensConnectionParams().getConf();
  }

  /**
   * Closes the underlying connection and releases the JAX-RS client.
   *
   * <p>Errors raised while closing the connection are logged rather than propagated. The client is closed even if
   * closing the connection fails.
   */
  public void close() {
    try {
      connection.close();
    } catch (Exception exc) {
      LOG.error("Error closing connection", exc);
    } finally {
      client.close();
    }
  }

  /**
   * Returns the session handle used for requests.
   *
   * @return the explicit session handle given at construction if non-null, otherwise the connection's session handle
   */
  public LensSessionHandle getSessionHandle() {
    if (sessionHandle != null) {
      return sessionHandle;
    }
    return connection.getSessionHandle();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy