// com.datastax.insight.ml.spark.mllib.model.ModelHandler (Maven / Gradle / Ivy — newest version)
// NOTE: the two lines above were scraped page text, preserved here as a comment so the file compiles.
package com.datastax.insight.ml.spark.mllib.model;
import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.core.entity.Model;
import com.datastax.data.common.hadoop.HDFSUtil;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import com.datastax.util.lang.ReflectUtil;
import com.datastax.util.lang.StringUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.regression.RegressionModel;
import org.apache.spark.mllib.util.Saveable;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
/**
 * Utility operator that loads, saves, and applies Spark MLlib models.
 *
 * <p>Model metadata (implementing class and storage path) is persisted through the
 * agent DAO, reached reflectively via {@link PersistService#invoke}. Models are
 * (de)serialized through each MLlib model class's static {@code load} method,
 * invoked reflectively so any {@link Saveable} implementation is supported.
 *
 * <p>NOTE(review): this class is stateless; all members are static. Thread-safety
 * ultimately depends on the shared {@link SparkContext} — not verifiable from here.
 */
public class ModelHandler implements RDDOperator {

    /**
     * Loads a persisted model by id. The model's class name and HDFS path are
     * looked up from the agent DAO, then the model is deserialized.
     *
     * @param modelId id of the persisted model record
     * @return the deserialized MLlib model
     */
    public static Saveable load(String modelId){
        Model model = (Model) PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "getModel",
                new String[]{String.class.getTypeName()},
                new Object[]{modelId});
        return innerLoad(model.getModelClass(), model.getPath());
    }

    /** Loads a model using the shared application {@link SparkContext}. */
    private static Saveable innerLoad(String modelClass, String modelPath) {
        SparkContext ctx = SparkContextBuilder.getContext();
        return innerLoad(ctx, modelClass, modelPath);
    }

    /**
     * Deserializes a model by reflectively invoking the static
     * {@code load(SparkContext, String)} method every MLlib model class exposes.
     * This replaces a long hand-written if/else chain over model types.
     *
     * @param ctx        Spark context used for deserialization
     * @param modelClass fully-qualified class name of the model
     * @param modelPath  storage path the model was saved to
     */
    private static Saveable innerLoad(SparkContext ctx, String modelClass, String modelPath) {
        return (Saveable) ReflectUtil.invokeStaticMethod(modelClass, "load",
                new String[] { ctx.getClass().getTypeName(), String.class.getTypeName()},
                new Object[] { ctx, modelPath});
    }

    /**
     * Saves a model under {@code path/modelName} and records its metadata
     * (flow id, name, class, path) through the agent DAO.
     *
     * @param saveable  the trained model
     * @param modelName model name; becomes the final path segment
     * @param path      base directory (HDFS URL or local path)
     * @param replace   when true, delete any existing model at the target path first
     */
    public static void save(Saveable saveable, String modelName, String path, boolean replace){
        String modelPath = path + "/" + modelName;
        if(replace){
            deleteHDFS(modelPath);
        }
        saveable.save(SparkContextBuilder.getContext(), modelPath);
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModel",
                new String[]{Long.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), modelName, saveable.getClass().getName(), modelPath});
    }

    /**
     * Deletes a model directory on HDFS. Only acts when {@code path} carries an
     * {@code hdfs://host[:port]/} prefix; local paths are silently ignored.
     * Deletion failures are logged (best effort) rather than propagated, so a
     * failed cleanup does not abort a subsequent save.
     */
    private static void deleteHDFS(String path){
        String hdfs= StringUtil.substringIndent(path,"hdfs://","/");
        if(hdfs!=null){
            hdfs="hdfs://"+hdfs;
            // Strip the authority so HDFSUtil receives a filesystem-relative path.
            path=path.replace(hdfs,"");
            Configuration configuration=new Configuration();
            // fs.defaultFS supersedes the deprecated fs.default.name key.
            configuration.set("fs.defaultFS",hdfs);
            try {
                HDFSUtil.delete(configuration,path);
                System.out.println("datastax-Insight Model deleted with sucess, model path -> " + path);
            } catch (Exception e) {
                // Best effort: keep going even if the old model could not be removed.
                e.printStackTrace();
            }
        }
    }

    /**
     * Predicts a single label from a comma-separated feature string, loading
     * the model by its persisted id.
     */
    public static String predict(String modelId, String feature) {
        Saveable model = load(modelId);
        return predict(model, feature);
    }

    /** Batch variant of {@link #predict(String, String)}. */
    public static String[] predict(String modelId, String[] feature) {
        Saveable model = load(modelId);
        return predict(model, feature);
    }

    /**
     * Predicts a single label, loading the model directly from
     * {@code modelClass}/{@code modelPath} (no DAO lookup).
     */
    public static String predict(String modelClass, String modelPath, String feature) {
        Saveable model = innerLoad(modelClass, modelPath);
        return predict(model, feature);
    }

    /** Batch variant of {@link #predict(String, String, String)}. */
    public static String[] predict(String modelClass, String modelPath, String[] feature) {
        Saveable model = innerLoad(modelClass, modelPath);
        return predict(model, feature);
    }

    /**
     * Recommends the top {@code num} products for a user from a persisted ALS model.
     *
     * <p>NOTE(review): unlike every other method here, this one builds its own
     * local[*] SparkSession and stops it afterwards — if a shared context already
     * exists this stops it too; confirm whether that is intentional.
     *
     * @return recommended product ids, best first
     */
    public static int[] recommendProducts(String modelClass, String modelPath, int userId, int num) {
        SparkSession spark = SparkSession
                .builder()
                .appName("datastax-insight" + "-" + System.currentTimeMillis())
                .master("local[*]")
                .getOrCreate();
        Saveable model = innerLoad(spark.sparkContext(), modelClass, modelPath);
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        int[] result = Arrays.stream(als.recommendProducts(userId, num)).mapToInt(r->r.product()).toArray();
        spark.stop();
        return result;
    }

    /**
     * Recommends the top {@code num} users for a product from a persisted ALS model.
     *
     * @return recommended user ids, best first
     */
    public static int[] recommendUsers(String modelClass, String modelPath, int productId, int num) {
        Saveable model = innerLoad(modelClass, modelPath);
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        // BUGFIX: previously mapped r.product(), which returned the queried product
        // id repeated num times; the recommendations are the user() fields
        // (consistent with recommendUsers(Saveable, int, int) below).
        return Arrays.stream(als.recommendUsers(productId, num)).mapToInt(r->r.user()).toArray();
    }

    /** Recommends the top {@code num} product ids for {@code userId} from an already-loaded ALS model. */
    public static int[] recommendProducts(Saveable model, int userId, int num) {
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        return Arrays.stream(als.recommendProducts(userId, num)).mapToInt(r->r.product()).toArray();
    }

    /** Recommends the top {@code num} user ids for {@code productId} from an already-loaded ALS model. */
    public static int[] recommendUsers(Saveable model, int productId, int num) {
        MatrixFactorizationModel als = (MatrixFactorizationModel)model;
        return Arrays.stream(als.recommendUsers(productId, num)).mapToInt(r->r.user()).toArray();
    }

    /** Parses a comma-separated feature string into a dense vector and predicts. */
    private static String predict(Saveable model, String feature) {
        Vector vector = Vectors.dense(Arrays.stream(feature.split(",")).mapToDouble(x->Double.parseDouble(x)).toArray());
        return predict(model, vector);
    }

    /**
     * Dispatches prediction to the model's classification or regression API.
     *
     * @throws UnsupportedOperationException if the model supports neither
     */
    private static String predict(Saveable model, Vector feature) {
        if(model instanceof ClassificationModel) {
            return String.valueOf(((ClassificationModel) model).predict(feature));
        } else if(model instanceof RegressionModel) {
            return String.valueOf(((RegressionModel) model).predict(feature));
        }
        throw new UnsupportedOperationException("predict method is no support for Class:" + model.getClass().getTypeName());
    }

    /** Batch form: parses each comma-separated feature string and predicts each. */
    private static String[] predict(Saveable model, String[] features) {
        Vector[] vectors = Arrays.stream(features).map(d-> Vectors.dense(Arrays.stream(d.split(",")).mapToDouble(x->Double.parseDouble(x)).toArray())).toArray(Vector[]::new);
        return predict(model, vectors);
    }

    /** Applies single-vector prediction to each feature vector. */
    private static String[] predict(Saveable model, Vector[] features) {
        return Arrays.stream(features).map(v->String.valueOf(predict(model, v))).toArray(String[]::new);
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (scraped page footer, preserved as a comment)