com.datastax.insight.ml.spark.mllib.model.ModelHandler

package com.datastax.insight.ml.spark.mllib.model;

import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.core.entity.Model;
import com.datastax.data.common.hadoop.HDFSUtil;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import com.datastax.util.lang.ReflectUtil;
import com.datastax.util.lang.StringUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.regression.RegressionModel;
import org.apache.spark.mllib.util.Saveable;
import org.apache.spark.sql.SparkSession;

import java.util.Arrays;

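/**
 * Loads, saves, and applies Spark MLlib models that implement {@link Saveable}. Model metadata
 * (model class and storage path) is resolved through {@link PersistService}, and loading is
 * dispatched reflectively to the model class's static {@code load(SparkContext, String)} method.
 *
 * <p>Minimal usage sketch; the model id, class, and HDFS path below are placeholders:
 * <pre>{@code
 * // classify a single comma-separated feature vector with a persisted model
 * String label = ModelHandler.predict("42", "1.0,0.5,3.2");
 *
 * // top-5 product recommendations from a stored ALS model
 * int[] products = ModelHandler.recommendProducts(
 *         MatrixFactorizationModel.class.getName(),
 *         "hdfs://namenode:8020/models/als", 7, 5);
 * }</pre>
 */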
public class ModelHandler implements RDDOperator {

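    /**
     * Loads a persisted model by its id. The model class and storage path are looked up
     * through the agent's {@code InsightDAO} via {@link PersistService#invoke}.
     */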
    public static Saveable load(String modelId){
//        Model model = PersistService.getModel(modelId);

        Model model = (Model) PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "getModel",
                new String[]{String.class.getTypeName()},
                new Object[]{modelId});

        return innerLoad(model.getModelClass(), model.getPath());
    }

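    /**
     * Loads a model by reflectively invoking the static {@code load(SparkContext, String)} method
     * of the given model class, using the shared context from {@link SparkContextBuilder}.
     */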
    private static Saveable innerLoad(String modelClass, String modelPath) {
        SparkContext ctx = SparkContextBuilder.getContext();
        return innerLoad(ctx, modelClass, modelPath);
    }

    private static Saveable innerLoad(SparkContext ctx, String modelClass, String modelPath) {

//        Model modelInfo = getModel(Integer.parseInt(path));
//        String modelPath = modelInfo.getPath();
//        Saveable model = null;

        return (Saveable) ReflectUtil.invokeStaticMethod(modelClass, "load",
                new String[] { ctx.getClass().getTypeName(), String.class.getTypeName()},
                new Object[] { ctx, modelPath});

//        if(type.equals("LogisticRegression")) {
//            model = LogisticRegressionModel.load(ctx, modelPath);
//        } else if(type.equals("NaiveBayes")) {
//            model = NaiveBayesModel.load(ctx, modelPath);
//        } else if(type.equals("SVM")) {
//            model = SVMModel.load(ctx, modelPath);
//        } else if(type.equals("BisectingKMeans")) {
//            model = BisectingKMeansModel.load(ctx, modelPath);
//        } else if(type.equals("DistributedLDA")) {
//            model = DistributedLDAModel.load(ctx, modelPath);
//        } else if(type.equals("GaussianMixture")) {
//            model = GaussianMixtureModel.load(ctx, modelPath);
//        } else if(type.equals("KMeans")){
//            model = KMeansModel.load(ctx, modelPath);
//        }  else if(type.equals("LocalLDA")) {
//            model = LocalLDAModel.load(ctx, modelPath);
//        } else if(type.equals("PowerIterationClustering")) {
//            model = PowerIterationClusteringModel.load(ctx, modelPath);
//        } else if(type.equals("StreamingKMeans")) {
//            model = StreamingKMeansModel.load(ctx, modelPath);
//        } else if(type.equals("PrefixSpan")) {
//            model = PrefixSpanModel.load(ctx, modelPath);
//        } else if(type.equals("MatrixFactorization")) {
//            model = MatrixFactorizationModel.load(ctx, modelPath);
//        } else if(type.equals("IsotonicRegression")) {
//            model = IsotonicRegressionModel.load(ctx, modelPath);
//        } else if(type.equals("Lasso")) {
//            model = LassoModel.load(ctx, modelPath);
//        } else if(type.equals("RidgeRegression")) {
//            model = RidgeRegressionModel.load(ctx, modelPath);
//        } else if(type.equals("DecisionTree")) {
//            model = DecisionTreeModel.load(ctx, modelPath);
//        } else if(type.equals("GradientBoostedTrees")) {
//            model = GradientBoostedTreesModel.load(ctx, modelPath);
//        } else if(type.equals("RandomForest")) {
//            model = RandomForestModel.load(ctx, modelPath);
//        }
//        return model;
    }

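    /**
     * Saves the model under {@code path/modelName} and registers it with the agent's
     * {@code InsightDAO}. When {@code replace} is true, any existing model at that
     * HDFS location is deleted first.
     */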
    public static void save(Saveable saveable, String modelName, String path, boolean replace){

        String modelPath = path + "/" + modelName;

        if(replace){
            deleteHDFS(modelPath);
        }
        saveable.save(SparkContextBuilder.getContext(), modelPath);

//        PersistService.saveModel(modelName, path);
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModel",
                new String[]{Long.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), modelName, saveable.getClass().getName(), modelPath});
    }

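    /** Deletes a model directory from HDFS; the namenode URI is extracted from the path itself. */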
    private static void deleteHDFS(String path) {
        String hdfs = StringUtil.substringIndent(path, "hdfs://", "/");
        if (hdfs != null) {
            hdfs = "hdfs://" + hdfs;
            path = path.replace(hdfs, "");
            Configuration configuration = new Configuration();
            // fs.defaultFS supersedes the deprecated fs.default.name key
            configuration.set("fs.defaultFS", hdfs);
            try {
                HDFSUtil.delete(configuration, path);
                System.out.println("datastax-insight model deleted successfully, model path -> " + path);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

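    /**
     * Prediction entry points. Features are comma-separated numeric strings, given either as a
     * single vector or as an array of vectors; the model is resolved by its persisted id or by
     * an explicit model class and path.
     */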
    public static String predict(String modelId, String feature) {
        Saveable model = load(modelId);
        return predict(model, feature);
    }

    public static String[] predict(String modelId, String[] feature) {
        Saveable model = load(modelId);
        return predict(model, feature);
    }

    public static String predict(String modelClass, String modelPath, String feature) {
        Saveable model = innerLoad(modelClass, modelPath);
        return predict(model, feature);
    }

    public static String[] predict(String modelClass, String modelPath, String[] feature) {
        Saveable model = innerLoad(modelClass, modelPath);
        return predict(model, feature);
    }

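    /**
     * Recommends the top {@code num} products for a user from a stored ALS
     * {@link MatrixFactorizationModel}. Unlike {@link #recommendUsers(String, String, int, int)},
     * this overload creates its own local SparkSession and stops it when done.
     */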
    public static int[] recommendProducts(String modelClass, String modelPath, int userId, int num) {

        SparkSession spark = SparkSession
                .builder()
                .appName("datastax-insight" + "-" + System.currentTimeMillis())
                .master("local[*]")
                .getOrCreate();

        Saveable model = innerLoad(spark.sparkContext(), modelClass, modelPath);
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        int[] result = Arrays.stream(als.recommendProducts(userId, num)).mapToInt(r->r.product()).toArray();

        spark.stop();

        return result;
    }

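    /**
     * Recommends the top {@code num} users for a product from a stored ALS model,
     * using the shared SparkContext.
     */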
    public static int[] recommendUsers(String modelClass, String modelPath, int productId, int num) {
        Saveable model = innerLoad(modelClass, modelPath);
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        return Arrays.stream(als.recommendUsers(productId, num)).mapToInt(r->r.user()).toArray();
    }

    public static int[] recommendProducts(Saveable model, int userId, int num) {
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        return Arrays.stream(als.recommendProducts(userId, num)).mapToInt(r->r.product()).toArray();
    }

    public static int[] recommendUsers(Saveable model, int productId, int num) {
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        return Arrays.stream(als.recommendUsers(productId, num)).mapToInt(r->r.user()).toArray();
    }

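    /** Parses a comma-separated feature string into a dense vector and delegates to the vector overload. */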
    private static String predict(Saveable model, String feature) {
        Vector vector = Vectors.dense(Arrays.stream(feature.split(",")).mapToDouble(x->Double.parseDouble(x)).toArray());
        return predict(model, vector);
    }

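    /** Dispatches to ClassificationModel or RegressionModel; other model types are not supported. */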
    private static String predict(Saveable model, Vector feature) {

        if(model instanceof ClassificationModel) {
            return String.valueOf(((ClassificationModel) model).predict(feature));
        } else if(model instanceof RegressionModel) {
            return String.valueOf(((RegressionModel) model).predict(feature));
        }

        throw new UnsupportedOperationException("predict is not supported for class: " + model.getClass().getTypeName());
    }

    private static String[] predict(Saveable model, String[] features) {
        Vector[] vectors = Arrays.stream(features).map(d-> Vectors.dense(Arrays.stream(d.split(",")).mapToDouble(x->Double.parseDouble(x)).toArray())).toArray(Vector[]::new);
        return predict(model, vectors);
    }

    private static String[] predict(Saveable model, Vector[] features) {
        return Arrays.stream(features).map(v->String.valueOf(predict(model, v))).toArray(String[]::new);
    }

//    private static Model getModel(int id) {
//        List data = (List) Cache.getCache("models");
//        if(data != null) {
//            return data.stream().filter(d->d.getId() == id).findFirst().orElse(null);
//        }
//        return null;
//    }
}