
prerna.reactor.frame.r.analytics.RunClassificationReactor Maven / Gradle / Ivy
The newest version!
package prerna.reactor.frame.r.analytics;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.apache.logging.log4j.Logger;
import prerna.ds.OwlTemporalEngineMeta;
import prerna.ds.r.RDataTable;
import prerna.ds.r.RSyntaxHelper;
import prerna.query.interpreters.RInterpreter;
import prerna.query.querystruct.SelectQueryStruct;
import prerna.query.querystruct.selectors.QueryColumnSelector;
import prerna.query.querystruct.transform.QSAliasToPhysicalConverter;
import prerna.reactor.frame.r.AbstractRFrameReactor;
import prerna.sablecc2.om.GenRowStruct;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.Utility;
import prerna.util.usertracking.AnalyticsTrackerHelper;
import prerna.util.usertracking.UserTrackerFactory;
public class RunClassificationReactor extends AbstractRFrameReactor {
/**
* Example invocations of this reactor:
* RunClassification(classify=[Species],attributes=["PetalLength","PetalWidth","SepalLength","SepalWidth"], panel=[0])
* RunClassification(classify=[race],attributes=["age","workclass","education","marital_status","relationship","sex","capital_gain","capital_loss","income"], panel=[0])
*/
// Fully qualified name of this class; handed to getLogger(...) in execute().
private static final String CLASS_NAME = RunClassificationReactor.class.getName();
// Input key naming the column to classify; registered as the first entry of keysToGet.
private static final String CLASSIFICATION_COLUMN = "classify";
public RunClassificationReactor() {
this.keysToGet = new String[] { CLASSIFICATION_COLUMN, ReactorKeysEnum.ATTRIBUTES.getKey(),
ReactorKeysEnum.PANEL.getKey() };
}
@Override
public NounMetadata execute() {
Logger logger = getLogger(CLASS_NAME);
init();
String[] packages = new String[] { "data.table", "partykit", "dplyr", "naniar" };
this.rJavaTranslator.checkPackages(packages);
RDataTable frame = (RDataTable) getFrame();
OwlTemporalEngineMeta meta = this.getFrame().getMetaData();
String dtName = frame.getName();
boolean implicitFilter = false;
String dtNameIF = "dtFiltered" + Utility.getRandomString(6);
StringBuilder rsb = new StringBuilder();
// load packages
rsb.append("library('partykit');library('naniar');");
// figure out inputs
String predictionCol = getClassificationColumn();
String predictionCol_R = "predictionCol" + Utility.getRandomString(8);
rsb.append(predictionCol_R + "<- \"" + predictionCol + "\";");
List attributes = getColumns();
if (attributes.contains(predictionCol)) {
attributes.remove(predictionCol);
}
if(attributes.isEmpty()) {
throw new IllegalArgumentException("Must define at least one attribute that is not the dimension to classify");
}
String attributes_R = "attributes" + Utility.getRandomString(8);
rsb.append(attributes_R + "<- " + RSyntaxHelper.createStringRColVec(attributes.toArray()) + ";");
// check if there are filters on the frame. if so then need to run algorithm on subsetted data
if(!frame.getFrameFilters().isEmpty()) {
// create a new qs to retrieve filtered frame
SelectQueryStruct qs = new SelectQueryStruct();
List selectedCols = new ArrayList(attributes);
selectedCols.add(predictionCol);
for(String s : selectedCols) {
qs.addSelector(new QueryColumnSelector(s));
}
qs.setImplicitFilters(frame.getFrameFilters());
qs = QSAliasToPhysicalConverter.getPhysicalQs(qs, meta);
RInterpreter interp = new RInterpreter();
interp.setQueryStruct(qs);
interp.setDataTableName(dtName);
interp.setColDataTypes(meta.getHeaderToTypeMap());
String query = interp.composeQuery();
this.rJavaTranslator.runR(dtNameIF + "<- {" + query + "}");
implicitFilter = true;
//cleanup the temp r variable in the query var
this.rJavaTranslator.runR("rm(" + query.split(" <-")[0] + ");gc();");
}
String targetDt = implicitFilter ? dtNameIF : dtName;
//validate that the count of unique values in the instance column != number of rows in the frame
int nrows = frame.getNumRows(targetDt);
int uniqInstCount = this.rJavaTranslator.getInt("if (is.factor(" + targetDt + "$" + predictionCol + ")) "
+ "length(levels(" + targetDt + "$" + predictionCol + ")) else length(unique(" + targetDt + "$" + predictionCol + "));");
if (nrows == uniqInstCount) {
throw new IllegalArgumentException("Values in the column to classify are all unique; classification algorithm is not applicable.");
}
// clustering r script
String classificationScriptFilePath = getBaseFolder() + "\\R\\AnalyticsRoutineScripts\\Classification.R";
classificationScriptFilePath = classificationScriptFilePath.replace("\\", "/");
rsb.append("source(\"" + classificationScriptFilePath + "\");");
String outputList_R = "outputList" + Utility.getRandomString(8);
// set call to R function
rsb.append(outputList_R + " <- getCTree( " + targetDt + "," + predictionCol_R + "," + attributes_R + ");");
// execute R
this.rJavaTranslator.runR(rsb.toString());
String[] predictors = this.rJavaTranslator.getStringArray(outputList_R + "$predictors;");
String accuracy = this.rJavaTranslator.getString(outputList_R + "$accuracy;");
String[] ctreeArray = this.rJavaTranslator.getStringArray(outputList_R + "$tree;");
//// clean up r temp variables
StringBuilder cleanUpScript = new StringBuilder();
cleanUpScript.append("rm(" + outputList_R + "," + predictionCol_R + "," + attributes_R + "," + dtNameIF + ",getCTree,getUsefulPredictors);");
cleanUpScript.append("gc();");
this.rJavaTranslator.runR(cleanUpScript.toString());
if (ctreeArray == null || ctreeArray.length == 0) {
Map vizData = new HashMap();
vizData.put("name", "Decision Tree For " + predictionCol);
vizData.put("layout", "Dendrogram");
vizData.put("panelId", getPanelId());
// make an empty map
Map classificationMap = new HashMap();
classificationMap.put("No Tree Generated", new HashMap());
vizData.put("children", classificationMap);
NounMetadata noun = new NounMetadata(vizData, PixelDataType.CUSTOM_DATA_STRUCTURE, PixelOperationType.VIZ_OUTPUT);
noun.addAdditionalReturn(
new NounMetadata("A decision tree could not be constructed for the requested dataset. Please retry with different data points.",
PixelDataType.CONST_STRING, PixelOperationType.ERROR));
return noun;
}
Map vizData = new HashMap();
vizData.put("name", "Decision Tree For " + predictionCol);
vizData.put("layout", "Dendrogram");
vizData.put("panelId", getPanelId());
// add the actual data
Map classificationMap = processTreeString(ctreeArray);
vizData.put("children", classificationMap);
// add the accuracy and predictors
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy