#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
#
# THIS SCRIPT COMPUTES LABEL PREDICTIONS ON A HELD-OUT TEST SET USING A LEARNED DECISION TREE MODEL.
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME   TYPE     DEFAULT   MEANING
# ---------------------------------------------------------------------------------------------
# X      String   ---       Location to read the test feature matrix X; note that X needs to be both recoded and dummy coded
# Y      String   " "       Location to read the true label matrix Y, if requested; note that Y needs to be both recoded and dummy coded
# R      String   " "       Location to read matrix R which, for each feature in X, contains the following information
#                           - R[,1]: column ids
#                           - R[,2]: start indices
#                           - R[,3]: end indices
#                           If R is not provided, all features are assumed to be scale
#                           (a commented example of R and M follows this table)
# M      String   ---       Location to read matrix M containing the learned tree in the following format
#                           - M[1,j]: id of node j (in a complete binary tree)
#                           - M[2,j]: offset (no. of columns) to the left child of j if j is an internal node, otherwise 0
#                           - M[3,j]: index of the feature that node j tests if j is an internal node, otherwise 0
#                           - M[4,j]: type of the feature that node j tests if j is an internal node (1 for scale, 2 for categorical),
#                                     otherwise the label that leaf node j predicts
#                           - M[5,j]: if j is an internal node: 1 if the feature chosen for j is scale, otherwise the size of the subset
#                                     of values stored in rows 6,7,... if the feature is categorical;
#                                     if j is a leaf node: the number of misclassified samples reaching node j
#                           - M[6:,j]: if j is an internal node and the feature chosen for j is scale: M[6,j] holds the threshold the
#                                     example's feature value is compared to; if the feature chosen for j is categorical, rows 6,7,...
#                                     hold the subset of values that routes an example to the left branch;
#                                     if j is a leaf node: 1 if j is impure and the number of samples at j exceeds the threshold, otherwise 0
# P      String   ---       Location to store the label predictions for X
# A      String   " "       Location to write the test accuracy (%) of the predictions, if requested
# CM     String   " "       Location to write the confusion matrix, if requested
# fmt    String   "text"    Output format of the results: "text" or "csv"
# ---------------------------------------------------------------------------------------------
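#
# Example (illustrative only; all values below are hypothetical): for a data set
# whose first feature is scale (column 1), second feature is categorical with 3
# levels (dummy coded over columns 2-4), and third feature is scale (column 5),
# R would be the 3 x 3 matrix
#
#   1 1 1
#   2 2 4
#   3 5 5
#
# and a single-split tree on (scale) feature 1 with threshold 0.5, predicting
# label 1 at the left leaf and label 2 at the right leaf, could be encoded as
# the following M (one column per node):
#
#   M = [ 1    2   3      <- M[1,]: node ids
#         1    0   0      <- M[2,]: offsets to left children (0 at leaves)
#         1    0   0      <- M[3,]: split feature index (0 at leaves)
#         1    1   2      <- M[4,]: feature type at node 1; predicted labels at leaves
#         1    0   0      <- M[5,]: scale flag at node 1; misclassified counts at leaves
#         0.5  0   0 ]    <- M[6,]: split threshold at node 1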
# OUTPUT:
# 1- Matrix of predicted labels for X, written to file P
# 2- Test accuracy (%), written to file A (or printed), if requested
# 3- Confusion matrix, written to file CM, if requested
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemML.jar -f decision-tree-predict.dml -nvargs X=INPUT_DIR/X Y=INPUT_DIR/Y R=INPUT_DIR/R M=INPUT_DIR/model P=OUTPUT_DIR/predictions
#                           A=OUTPUT_DIR/accuracy CM=OUTPUT_DIR/confusion fmt=csv
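#
# Alternatively, in Spark batch mode (a hypothetical invocation; the exact
# spark-submit flags depend on the cluster setup):
# spark-submit SystemML.jar -f decision-tree-predict.dml -nvargs X=INPUT_DIR/X M=INPUT_DIR/model P=OUTPUT_DIR/predictions

# parse named arguments; optional inputs default to the sentinel " "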
fileX = $X;
fileM = $M;
fileP = $P;
fileY = ifdef ($Y, " ");
fileR = ifdef ($R, " ");
fileCM = ifdef ($CM, " ");
fileA = ifdef ($A, " ");
fmtO = ifdef ($fmt, "text");
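# read the test features and the learned tree model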
X_test = read (fileX);
M = read (fileM);
num_records = nrow (X_test);
Y_predicted = matrix (0, rows = num_records, cols = 1);
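# R_cat will hold the [start, end] column range of each categorical feature;
# R_scale will hold the single column index of each scale feature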
R_cat = matrix (0, rows = 1, cols = 1);
R_scale = matrix (0, rows = 1, cols = 1);
if (fileR != " ") {
    R = read (fileR);
    # a feature whose start and end indices differ spans multiple (dummy-coded)
    # columns and is therefore categorical
    dummy_coded = (R[,2] != R[,3]);
    R_scale = removeEmpty (target = R[,2] * (1 - dummy_coded), margin = "rows");
    R_cat = removeEmpty (target = R[,2:3] * dummy_coded, margin = "rows");
} else { # only scale features available
    R_scale = seq (1, ncol (X_test));
}
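# traverse the tree for each test record; iterations are independent
# (each writes a disjoint row of Y_predicted), hence check = 0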
parfor (i in 1:num_records, check = 0) {
    cur_sample = X_test[i,];
    cur_node_pos = 1; # start at the root (first column of M)
    label_found = FALSE;
    while (!label_found) {
        cur_feature = as.scalar (M[3,cur_node_pos]);
        type_label = as.scalar (M[4,cur_node_pos]);
        if (cur_feature == 0) { # leaf node: M[4,.] holds the predicted label
            label_found = TRUE;
            Y_predicted[i,] = type_label;
        } else {
            # internal node: M[4,.] holds the feature type (1 for scale, 2 for categorical)
            if (type_label == 1) { # scale feature: threshold comparison
                cur_start_ind = as.scalar (R_scale[cur_feature,]);
                cur_value = as.scalar (cur_sample[,cur_start_ind]);
                cur_split = as.scalar (M[6,cur_node_pos]);
                if (cur_value < cur_split) { # go to left branch
                    cur_node_pos = cur_node_pos + as.scalar (M[2,cur_node_pos]);
                } else { # go to right branch
                    cur_node_pos = cur_node_pos + as.scalar (M[2,cur_node_pos]) + 1;
                }
            } else if (type_label == 2) { # categorical feature: subset membership test
                cur_start_ind = as.scalar (R_cat[cur_feature,1]);
                cur_end_ind = as.scalar (R_cat[cur_feature,2]);
                # decode the dummy-coded category: index of the 1 within the feature's column range
                cur_value = as.scalar (rowIndexMax (cur_sample[,cur_start_ind:cur_end_ind]));
                cur_offset = as.scalar (M[5,cur_node_pos]);
                value_found = sum (M[6:(6 + cur_offset - 1),cur_node_pos] == cur_value);
                if (value_found >= 1) { # value in subset: go to left branch
                    cur_node_pos = cur_node_pos + as.scalar (M[2,cur_node_pos]);
                } else { # go to right branch
                    cur_node_pos = cur_node_pos + as.scalar (M[2,cur_node_pos]) + 1;
                }
            }
        }
    }
}
write (Y_predicted, fileP, format = fmtO);
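# if true labels are provided, evaluate the predictions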
if (fileY != " ") {
    Y_test = read (fileY);
    num_classes = ncol (Y_test);
    Y_test = rowSums (Y_test * t (seq (1, num_classes))); # decode one-hot labels to class ids in 1..num_classes
    result = (Y_test == Y_predicted);
    result = sum (result);
    accuracy = result / num_records * 100;
    acc_str = "Accuracy (%): " + accuracy;
    if (fileA != " ") {
        write (acc_str, fileA, format = fmtO);
    } else {
        print (acc_str);
    }
    if (fileCM != " ") {
        confusion_mat = table (Y_predicted, Y_test, num_classes, num_classes); # rows: predicted, columns: actual
        write (confusion_mat, fileCM, format = fmtO);
    }
}