All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datarobot.impl.PredictionClient Maven / Gradle / Ivy

Go to download

This Java client library allows you to quickly and easily communicate with the DataRobot AI API to add machine learning to your applications.

The newest version!
package com.datarobot.impl;

import com.datarobot.IDataRobotAIClient;
import com.datarobot.IPredictionClient;
import com.datarobot.IAIClient;
import com.datarobot.model.AI;
import com.datarobot.model.Dataset;
import com.datarobot.model.DatasetImportResponse;
import com.datarobot.model.DatasetStreamSource;
import com.datarobot.model.Deployment;
import com.datarobot.model.IDatasetSource;
import com.datarobot.model.Output;
import com.datarobot.model.PredictionList;
import com.datarobot.util.Action;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpResponse;

import java.io.*;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

/**
 * The {@link PredictionClient} object provides access to the prediction
 * endpoints of the DataRobot AI API. This object is not meant to be used
 * directly but through the {@code DataRobotAIClient.predict()} method.
 */
public class PredictionClient implements IPredictionClient {
    /** Owning client; supplies the HTTP connection and the AI lookups used here. */
    private final IDataRobotAIClient client;

    /**
     * Hook passed unchanged as the last argument of every {@code postStream}
     * call made by this client. Its effect is defined by the connection
     * implementation. {@code null} until set.
     */
    private Action httpMessageTransformer = null;

    /**
     * {@link IPredictionClient} based API operations. Users will not need to
     * instantiate this object directly. It can be accessed through
     * {@code DataRobotAIClient.predict()}.
     *
     * @param client the {@link IDataRobotAIClient} whose connection and AI
     *               operations this prediction client uses
     */
    public PredictionClient(IDataRobotAIClient client) {
        this.client = client;
    }

    /**
     * internal
     *
     * @return the transformer forwarded to prediction requests, or {@code null}
     *         if none has been set
     */
    public Action getHttpMessageTransformer() {
        return httpMessageTransformer;
    }

    /**
     * internal
     *
     * @param httpMessageTransformer transformer to forward to subsequent
     *                               prediction requests; may be {@code null}
     */
    public void setHttpMessageTransformer(Action httpMessageTransformer) {
        this.httpMessageTransformer = httpMessageTransformer;
    }

    /**
     * Send a prediction request to the specified deployment
     *
     * @param deployment Who will service the request
     * @param sourceFile The data on which to predict via a filepath on the local
     *                   system
     *
     * @return {@link PredictionList}
     *
     * @throws ClientException       when 4xx or 5xx response is received from
     *                               server, or errors in parsing the response.
     * @throws FileNotFoundException when a file with the specified pathname does
     *                               not exist, or if the file does exist but is
     *                               inaccessible for some reason.
     */
    @Override
    public PredictionList deploymentPredict(Deployment deployment, String sourceFile)
            throws ClientException, FileNotFoundException {
        Argument.IsNotNull(deployment, "deployment");
        Argument.IsNotNullOrEmpty(sourceFile, "sourceFile");

        return makePredictionRequest(sourceFile, deployment);
    }

    /**
     * Retrieve AI predictions against data. Note an AI must be trained with
     * {@link AI#learn} or an existing learning session must be added to the AI with
     * {@link AI#addLearningSession} or {@link IAIClient#addLearningSession} before
     * predictions can occur.
     *
     * @param aiId   The ID of the AI which to predict on
     * @param target The name of the selected target feature to predict
     * @param data   The data on which to predict via a filepath on the local system
     *
     * @return {@link PredictionList}
     *
     * @throws ClientException       when 4xx or 5xx response is received from
     *                               server, or errors in parsing the response.
     * @throws FileNotFoundException when a file with the specified pathname does
     *                               not exist, or if the file does exist but is
     *                               inaccessible for some reason.
     */
    @Override
    public PredictionList aiPredict(String aiId, String target, String data)
            throws ClientException, FileNotFoundException {
        Argument.IsNotNullOrEmpty(aiId, "aiId");
        Argument.IsNotNullOrEmpty(target, "target");
        Argument.IsNotNullOrEmpty(data, "data");

        // Retrieving the output for this AI
        Output output = client.ais().getOutput(aiId, target);
        // Creating a deployment from the output
        Deployment deployment = new Deployment(output.getUrl(), output.getDeploymentId(), output.getTarget(),
                output.getDataRobotKey(), output.getModelType());

        return makePredictionRequest(data, deployment);
    }

    /**
     * Streams the local file at {@code sourceFile} as {@code text/csv} to the
     * deployment's URL and returns the response parsed as a {@link PredictionList}.
     *
     * @param sourceFile local path of the data file to upload
     * @param deployment target deployment; supplies the request URL and is also
     *                   handed to the connection as the {@code "deployment"}
     *                   parameter
     * @return the parsed {@link PredictionList}
     * @throws FileNotFoundException as propagated from the connection's
     *                               {@code postStream} call
     * @throws ClientException       as propagated from the connection's
     *                               {@code postStream} call
     */
    private PredictionList makePredictionRequest(String sourceFile, Deployment deployment)
            throws FileNotFoundException, ClientException {
        // datarobot-key needs to be included within the prediction request to
        // access the deployment URL; the connection receives it via the
        // "deployment" parameter.
        Map<String, Object> parameters = new HashMap<String, Object>();
        parameters.put("deployment", deployment);

        File predictionFile = new File(sourceFile);

        return client.getConnection().postStream(PredictionList.class, deployment.getUrl(), parameters,
                predictionFile, "text/csv", this.httpMessageTransformer);
    }

    @Override
    public String toString() {
        return "PredictionClient";
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy