All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scripts.algorithms.random-forest-predict.dml Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

#  
# THIS SCRIPT COMPUTES LABEL PREDICTIONS MEANT FOR USE WITH A RANDOM FOREST MODEL ON A HELD OUT TEST SET 
# OR FOR COMPUTING THE OUT-OF-BAG ERROR ON THE TRAINING SET.
#
# INPUT         PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME          TYPE     DEFAULT      MEANING
# ---------------------------------------------------------------------------------------------
# X             String   ---          Location to read test feature matrix or training feature matrix for computing Out-Of-Bag error; 
#									  note that X needs to be both recoded and dummy coded 
# Y	 		    String   " "		  Location to read true label matrix Y if requested; note that Y needs to be both recoded and dummy coded
# R   	  		String   " "	      Location to read the matrix R which for each feature in X contains the following information 
#										- R[,1]: column ids
#										- R[,2]: start indices 
#										- R[,3]: end indices
#									  If R is not provided by default all variables are assumed to be scale
# M             String 	 ---	   	  Location to read matrix M containing the learned trees in the following format
#								 		- M[1,j]: id of node j (in a complete binary tree)
#										- M[2,j]: tree id 
#	 									- M[3,j]: Offset (no. of columns) to left child of j if j is an internal node, otherwise 0
#	 									- M[4,j]: Feature index of the feature that node j looks at if j is an internal node, otherwise 0
#	 									- M[5,j]: Type of the feature that node j looks at if j is an internal node: 1 for scale and 2 for categorical features, 
#		     									  otherwise the label that leaf node j is supposed to predict
#	 									- M[6,j]: If j is an internal node: 1 if the feature chosen for j is scale, otherwise the size of the subset of values 
#			 									  stored in rows 7,8,... if j is categorical 
#						 						  If j is a leaf node: number of misclassified samples reaching node j 
#	 									- M[7:,j]: If j is an internal node: Threshold the example's feature value is compared to is stored at M[7,j] 
#							   					   if the feature chosen for j is scale, otherwise if the feature chosen for j is categorical rows 7,8,... 
#												   depict the value subset chosen for j
#	          									   If j is a leaf node 1 if j is impure and the number of samples at j > threshold, otherwise 0
# C 			String   " "		  Location to read the counts matrix containing the number of times samples are chosen in each tree of the random forest
# P				String   ---		  Location to store the label predictions for X
# A     		String   " "          Location to store the test accuracy (%) for the prediction if requested
# OOB 			String   " "		  If C is provided location to store the Out-Of-Bag (OOB) error of the learned model 
# CM     		String   " "		  Location to store the confusion matrix if requested 
# fmt     	    String   "text"       The format of the output files, such as "text" or "csv"
# ---------------------------------------------------------------------------------------------
# OUTPUT: 
#	1- Matrix P containing the predicted labels for X 
#   2- Test accuracy if requested (requires Y)
#   3- Out-Of-Bag error if C is provided (requires Y)
#   4- Confusion matrix CM if requested (requires Y)
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemML.jar -f random-forest-predict.dml -nvargs X=INPUT_DIR/X Y=INPUT_DIR/Y R=INPUT_DIR/R M=INPUT_DIR/model P=OUTPUT_DIR/predictions
#														A=OUTPUT_DIR/accuracy CM=OUTPUT_DIR/confusion fmt=csv

# Format used for everything this script writes; falls back to "text".
fmtO = ifdef ($fmt, "text");

# Mandatory arguments: feature matrix, forest model, and where to store predictions.
# The two input matrices are loaded right away.
fileX = $X;
X = read (fileX);
fileM = $M;
M = read (fileM);
fileP = $P;

# Optional arguments. A single blank (" ") is the sentinel for "not supplied";
# the rest of the script tests against that exact value.
fileY = ifdef ($Y, " ");      # true labels (dummy coded)
fileR = ifdef ($R, " ");      # per-feature column ranges
fileC = ifdef ($C, " ");      # per-tree sample counts (enables Out-Of-Bag error)
fileA = ifdef ($A, " ");      # where to store the accuracy
fileOOB = ifdef ($OOB, " ");  # where to store the Out-Of-Bag error
fileCM = ifdef ($CM, " ");    # where to store the confusion matrix

# Basic dimensions of the data and of the forest encoded in M.
# (Y_predicted is not pre-allocated here: it is assigned unconditionally from
#  rowIndexMax (label_counts) once all trees have voted.)
num_records = nrow (X);

# Row 2 of M holds the tree id of every node, so its maximum is the number of trees
# (assumes tree ids are consecutive starting at 1 -- TODO confirm against the training script).
num_trees  = max (M[2,]);

# Row 5 of M holds the predicted label for leaf nodes but the feature type (1 or 2)
# for internal nodes, so this maximum is an upper bound on the number of labels;
# any surplus columns in the vote matrices simply stay zero.
num_labels = max (M[5,]);

# Number of nodes (columns of M) belonging to each tree, and the running total.
# num_nodes_per_tree_cum[t] + 1 is used as the column of the root of tree t+1
# (assumes the columns of M are grouped by tree in ascending tree id -- TODO confirm).
num_nodes_per_tree = aggregate (target = t (M[2,]), groups = t (M[2,]), fn = "count");
num_nodes_per_tree_cum = cumsum (num_nodes_per_tree);

# Lookup tables from feature index to column position(s) in X:
#   R_scale[k,1]   : column of the k-th scale feature
#   R_cat[k,1:2]   : first and last dummy-coded column of the k-th categorical feature
# Both start out as 1x1 zero placeholders so they are defined on every path.
R_scale = matrix (0, rows = 1, cols = 1);
R_cat = matrix (0, rows = 1, cols = 1);

if (fileR == " ") {
	# No R given: every column of X is treated as its own scale feature.
	R_scale = seq (1, ncol (X));
} else {
	R = read (fileR);
	# A feature whose start and end indices differ spans several columns, i.e. it is dummy coded.
	is_categorical = (R[,2] != R[,3]);
	# Zero out the rows of the other kind, then drop them.
	R_cat = removeEmpty (target = R[,2:3] * is_categorical, margin = "rows");
	R_scale = removeEmpty (target = R[,2] * (1 - is_categorical), margin = "rows");
}

# Optional Out-Of-Bag bookkeeping: C[i,t] is the number of times sample i was drawn
# for tree t; label_counts_oob collects votes only from trees that never saw the sample.
if (fileC != " ") {
	C = read (fileC);
	label_counts_oob = matrix (0, rows = num_records, cols = num_labels);
}

# label_counts[i,l] = number of trees that predict label l for record i.
label_counts = matrix (0, rows = num_records, cols = num_labels);

# Every record is routed through all trees, one after the other. Iterations only
# write their own row i of the vote matrices, hence the disabled dependency check.
parfor (i in 1:num_records, check = 0) {
	cur_sample = X[i,];
	cur_node_pos = 1;        # column of M holding the node currently visited
	cur_tree = 1;
	labels_found = FALSE;    # becomes TRUE once the last tree has voted
	while (!labels_found) {

		cur_feature = as.scalar (M[4,cur_node_pos]);
		# Row 5 is overloaded: predicted label at a leaf, feature type at an internal node.
		type_label = as.scalar (M[5,cur_node_pos]);

		if (cur_feature == 0) { # leaf found: record the vote of the current tree
			label_counts[i,type_label] = label_counts[i,type_label] + 1;
			if (fileC != " ") {
				# The sample is out-of-bag for this tree if it was never drawn for it.
				if (as.scalar (C[i,cur_tree]) == 0) label_counts_oob[i,type_label] = label_counts_oob[i,type_label] + 1;
			}
			if (cur_tree < num_trees) {
				# Jump to the root of the next tree: first column after all nodes of trees 1..cur_tree.
				cur_node_pos = as.scalar (num_nodes_per_tree_cum[cur_tree,]) + 1;
			} else if (cur_tree == num_trees) {
				labels_found = TRUE;
			}
			cur_tree = cur_tree + 1;
		} else { # internal node: evaluate the split and descend
			go_left = FALSE;
			if (type_label == 1) { # scale feature: compare against the threshold in M[7,]
				cur_start_ind = as.scalar (R_scale[cur_feature,]);
				cur_value = as.scalar (cur_sample[,cur_start_ind]);
				cur_split = as.scalar (M[7,cur_node_pos]);
				go_left = (cur_value < cur_split);
			} else if (type_label == 2) { # categorical feature: test membership in the stored value subset
				cur_start_ind = as.scalar (R_cat[cur_feature,1]);
				cur_end_ind = as.scalar (R_cat[cur_feature,2]);
				# Position of the maximum within the dummy-coded columns = category id of the sample.
				cur_value = as.scalar (rowIndexMax(cur_sample[,cur_start_ind:cur_end_ind]));
				# M[6,] holds the subset size; the subset itself sits in rows 7 .. 7 + size - 1.
				cur_offset = as.scalar (M[6,cur_node_pos]);
				value_found = sum (M[7:(7 + cur_offset - 1),cur_node_pos] == cur_value);
				go_left = (value_found >= 1);
			} else {
				# Previously an unknown type left cur_node_pos unchanged and the loop never terminated.
				stop ("random-forest-predict: invalid feature type " + type_label + " at column " + cur_node_pos + " of model M (expected 1 = scale or 2 = categorical)");
			}

			# M[3,] is the column offset to the left child; the right child is stored right after it.
			left_child_pos = cur_node_pos + as.scalar (M[3,cur_node_pos]);
			if (go_left) {
				cur_node_pos = left_child_pos;
			} else {
				cur_node_pos = left_child_pos + 1;
			}
		}
	}
}

# Majority vote: the predicted label of a record is the column with the most tree votes.
Y_predicted = rowIndexMax (label_counts);
write (Y_predicted, fileP, format = fmtO);

# All quality measures below need the true labels and are skipped when Y is not supplied.
if (fileY != " ") {
	Y_dummy = read (fileY);
	num_classes = ncol (Y_dummy);
	# Undo the dummy coding: weight column k by k and sum each row to recover the label id.
	class_ids = t (seq (1, num_classes));
	Y = rowSums (Y_dummy * class_ids);

	# --- accuracy over all records ---
	num_correct = sum (Y == Y_predicted);
	accuracy = num_correct / num_records * 100;
	acc_str = "Accuracy (%): " + accuracy;
	if (fileA == " ") {
		print (acc_str);
	} else {
		write (acc_str, fileA, format = fmtO);
	}

	# --- Out-Of-Bag error, restricted to records that got at least one OOB vote ---
	# NOTE(review): if no record has any OOB vote, the result depends on what removeEmpty
	# returns for an all-zero input -- confirm and decide how that case should be reported.
	# NOTE(review): Y * oob_ind relies on every true label id being non-zero; a record with
	# an all-zero row in Y would be dropped from Y_oob but not from the vote matrix -- confirm.
	if (fileC != " ") {
		oob_ind = (rowSums (label_counts_oob) > 0);
		votes_oob = removeEmpty (target = label_counts_oob, margin = "rows");
		num_oob = nrow (votes_oob);
		Y_predicted_oob = rowIndexMax (votes_oob);
		Y_oob = removeEmpty (target = Y * oob_ind, margin = "rows");
		num_correct_oob = sum (Y_oob == Y_predicted_oob);
		oob_error = (1 - (num_correct_oob / num_oob)) * 100;
		oob_str = "Out-Of-Bag error (%): " + oob_error;
		if (fileOOB == " ") {
			print (oob_str);
		} else {
			write (oob_str, fileOOB, format = fmtO);
		}
	}

	# --- confusion matrix: rows index the predicted label, columns the true label ---
	if (fileCM != " ") {
		confusion_mat = table (Y_predicted, Y, num_classes, num_classes);
		write (confusion_mat, fileCM, format = fmtO);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy