
com.lucidworks.spark.example.ml.UseML

package com.lucidworks.spark.example.ml;

import com.lucidworks.spark.SparkApp;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.StructType;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class UseML implements SparkApp.RDDProcessor {

    @Override
    public String getName() {
        return "use-ml";
    }

    @Override
    public Option[] getOptions() {
        return new Option[0];
    }

    @Override
    public int run(SparkConf conf, CommandLine cli) throws Exception {

        long startMs = System.currentTimeMillis();

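        // Disable the Spark web UI for this example job.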
        conf.set("spark.ui.enabled", "false");

        SparkSession sparkSession = SparkSession.builder().config(conf).getOrCreate();
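        // try-with-resources ensures the JavaSparkContext (and the SparkContext behind it) is closed on exit.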
        try (JavaSparkContext jsc = new JavaSparkContext(sparkSession.sparkContext())) {

            long diffMs = (System.currentTimeMillis() - startMs);
            System.out.println(">> took " + diffMs + " ms to create SQLContext");

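            // spark-solr data source options: ZooKeeper connect string, collection, query, and field list.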
            Map<String, String> options = new HashMap<>();
            options.put("zkhost", "localhost:9983");
            options.put("collection", "ml20news");
            options.put("query", "content_txt:[* TO *]");
            options.put("fields", "content_txt");

            Dataset<Row> solrData = sparkSession.read().format("solr").options(options).load();
            Dataset<Row> sample = solrData.sample(false, 0.1d, 5150).select("content_txt");
            List<Row> rows = sample.collectAsList();
            System.out.println(">> loaded " + rows.size() + " docs to classify");

            StructType schema = sample.schema();

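            // Load the CrossValidatorModel (presumably saved by a separate training run) and grab the winning pipeline.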
            CrossValidatorModel cvModel = CrossValidatorModel.load("ml-pipeline-model");
            PipelineModel bestModel = (PipelineModel) cvModel.bestModel();

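            // Score each sampled document individually to measure single-document prediction time.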
            int r = 0;
            startMs = System.currentTimeMillis();
            for (Row next : rows) {
                // Wrap the document text in a single-row DataFrame so the pipeline can transform it.
                Row oneRow = RowFactory.create(next.getString(0));
                Dataset<Row> oneRowDF = sparkSession.createDataFrame(Collections.singletonList(oneRow), schema);
                Dataset<Row> scored = bestModel.transform(oneRowDF);
                Row scoredRow = scored.collectAsList().get(0);
                String predictedLabel = scoredRow.getString(scoredRow.fieldIndex("predictedLabel"));

                // an actual app would save the predictedLabel; for debugging:
                //System.out.println(">> row[" + r + "] predicted label: " + predictedLabel);

                r++;
            }
            diffMs = (System.currentTimeMillis() - startMs);
            System.out.println(">> took " + diffMs + " ms to score " + rows.size() + " docs");
        }
        return 0;
    }
}
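A note on throughput: the loop above calls transform() once per document, which shows single-document latency but pays Spark's planning overhead on every row. A minimal batch variant, reusing bestModel and sample from run() above (the predictedLabel column name is assumed to match the training pipeline, as in the loop):

    // Score the whole sampled Dataset in a single transform() call.
    Dataset<Row> scoredAll = bestModel.transform(sample);

    // Collect just the predictions; a real app would write them back to Solr
    // or another store instead of collecting to the driver.
    for (Row row : scoredAll.select("predictedLabel").collectAsList()) {
        String predictedLabel = row.getString(0);
        // ... persist predictedLabel ...
    }

To launch the example, SparkApp dispatches on the name returned by getName(), so the job is typically submitted via spark-submit with com.lucidworks.spark.SparkApp as the main class and use-ml as the first program argument (exact jar path and flags depend on your spark-solr build).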



