scripts.algorithms.m-svm.dml Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Implements multiclass SVM with squared slack variables,
# learns one-against-the-rest binary-class classifiers
#
# Example Usage:
# Assume SVM_HOME is set to the home of the dml script
# Assume input and output directories are on hdfs as INPUT_DIR and OUTPUT_DIR
# Assume epsilon = 0.001, lambda=1.0, maxiterations = 100
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X String --- Location to read the matrix X of feature vectors
# Y String --- Location to read response matrix Y
# icpt Int 0 Intercept presence
# 0 = no intercept
# 1 = add intercept;
# tol Double 0.001 Tolerance (epsilon);
# reg Double 1.0 Regularization parameter
# maxiter Int 100 Maximum number of conjugate gradient iterations
# model String --- Location to write model
# fmt String "text" Format in which the model is written, such as "text" or "csv"
# Log String --- [OPTIONAL] Location to write the log file
# ---------------------------------------------------------------------------------------------
#
# hadoop jar SystemML.jar -f $SVM_HOME/m-svm.dml -nvargs X=$INPUT_DIR/X Y=$INPUT_DIR/y icpt=0 tol=.001 reg=1.0 maxiter=100 model=$OUTPUT_DIR/w Log=$OUTPUT_DIR/Log fmt="text"
#
# Resolve the optional named arguments; ifdef() yields the second operand
# whenever the corresponding $argument was not supplied on the command line.
cmdLine_maxiter = ifdef($maxiter, 100)
cmdLine_reg = ifdef($reg, 1.0)
cmdLine_tol = ifdef($tol, 0.001)
cmdLine_icpt = ifdef($icpt, 0)
cmdLine_fmt = ifdef($fmt, "text")

# Echo the effective numeric settings so they show up in the run output.
settings_msg = "icpt=" + cmdLine_icpt + " tol=" + cmdLine_tol + " reg=" + cmdLine_reg + " maxiter=" + cmdLine_maxiter
print(settings_msg)
# ---- Load the inputs and reject anything the solver cannot handle ----------

# Feature matrix: one row per training example.
X = read($X)
if (nrow(X) < 2) {
    stop("Stopping due to invalid inputs: Not possible to learn a classifier without at least 2 rows")
}
# Number of feature columns as read from disk; stored later in the model footer.
dimensions = ncol(X)

# Label matrix: must have exactly one row per row of X.
Y = read($Y)
if (nrow(X) != nrow(Y)) {
    stop("Stopping due to invalid argument: Numbers of rows in X and Y must match")
}

# Intercept switch: only the values 0 and 1 are accepted.
intercept = cmdLine_icpt
if (!(intercept == 0 | intercept == 1)) {
    stop("Stopping due to invalid argument: Currently supported intercept options are 0 and 1")
}

# Labels must start at 1 or above.
if (min(Y) < 1) {
    stop("Stopping due to invalid argument: Label vector (Y) must be recoded")
}

# The largest label value is taken as the number of classes.
num_classes = max(Y)
if (num_classes == 1) {
    stop("Stopping due to invalid argument: Maximum label value is 1, need more than one class to learn a multi-class classifier")
}

# Count the cells of Y that have no fractional part; the count has to equal
# the number of rows of Y for the labels to be accepted.
integral_cells = sum((Y %% 1) == 0)
if (integral_cells != nrow(Y)) {
    stop("Stopping due to invalid argument: Please ensure that Y contains (positive) integral labels")
}

# Convergence tolerance.
epsilon = cmdLine_tol
if (epsilon < 0) {
    stop("Stopping due to invalid argument: Tolerance (tol) must be non-negative")
}

# Regularization constant.
lambda = cmdLine_reg
if (lambda < 0) {
    stop("Stopping due to invalid argument: Regularization constant (reg) must be non-negative")
}

# Cap on the number of outer iterations per class.
maxiterations = cmdLine_maxiter
if (maxiterations < 1) {
    stop("Stopping due to invalid argument: Maximum iterations should be a positive integer")
}
# ---- Allocate the model and the per-iteration objective log ----------------

# Sizes of the data as read, before any intercept column is added.
num_features = ncol(X)
num_samples = nrow(X)

# With an intercept, X gains a trailing column of ones and every weight
# vector gains one matching row.
num_rows_in_w = num_features
if (intercept == 1) {
    X = cbind(X, matrix(1, rows=num_samples, cols=1))
    num_rows_in_w = num_features + 1
}

# One weight column per class.
w = matrix(0, rows=num_rows_in_w, cols=num_classes)

# debug_mat[i, c] will hold the objective of class c after iteration i;
# -1 marks an iteration that was never executed.
debug_mat = matrix(-1, rows=maxiterations, cols=num_classes)
# Train one binary squared-hinge-loss SVM per class (one-against-the-rest)
# using nonlinear conjugate gradient. For class iter_class the learned weights
# are stored in w[, iter_class] and the objective after each iteration in
# debug_mat[, iter_class].
parfor(iter_class in 1:num_classes){
    # +1 for rows labelled iter_class, -1 for every other row
    Y_local = 2 * (Y == iter_class) - 1

    # X already carries the appended column of ones when intercept == 1, so
    # ncol(X) is the full weight-vector length including the intercept entry.
    w_class = matrix(0, rows=ncol(X), cols=1)

    # Negative gradient of the objective at w_class = 0; also the first
    # search direction.
    g_old = t(X) %*% Y_local
    s = g_old

    # Xw always equals X %*% w_class; it is updated incrementally below.
    Xw = matrix(0, rows=nrow(X), cols=1)
    iter = 0

    # If the initial gradient is exactly zero the search direction is zero as
    # well; entering the loop would then compute 0/0 in the line search and
    # fill w_class with NaN. In that case w_class = 0 is kept as the result.
    continue = (sum(s^2) != 0)
    while(continue & iter < maxiterations) {
        # minimizing primal obj along direction s
        step_sz = 0
        Xd = X %*% s
        wd = lambda * sum(w_class * s)
        dd = lambda * sum(s * s)

        # Newton iterations on the one-dimensional problem in step_sz
        continue1 = TRUE
        while(continue1){
            tmp_Xw = Xw + step_sz*Xd
            out = 1 - Y_local * (tmp_Xw)
            sv = (out > 0)          # rows that currently violate the margin
            out = out * sv
            g = wd + step_sz*dd - sum(out * Y_local * Xd)   # first derivative
            h = dd + sum(Xd * sv * Xd)                       # second derivative
            step_sz = step_sz - g/h
            continue1 = (g*g/h >= 0.0000000001)
        }

        #update weights
        w_class = w_class + step_sz*s
        Xw = Xw + step_sz*Xd

        # objective and negative gradient at the new point
        out = 1 - Y_local * Xw
        sv = (out > 0)
        out = sv * out
        obj = 0.5 * sum(out * out) + lambda/2 * sum(w_class * w_class)
        g_new = t(X) %*% (out * Y_local) - lambda * w_class

        tmp = sum(s * g_old)

        # Xw is reused here instead of recomputing X %*% w_class, which would
        # cost an extra pass over X per iteration just for this log line.
        train_acc = sum(Y_local*Xw >= 0)/num_samples*100
        print("For class " + iter_class + " iteration " + iter + " training accuracy: " + train_acc)
        debug_mat[iter+1,iter_class] = obj

        #non-linear CG step
        be = sum(g_new * g_new)/sum(g_old * g_old)
        s = be * s + g_new
        g_old = g_new

        # stop once the achieved decrease is small relative to the objective,
        # or the next search direction has vanished
        continue = (step_sz*tmp >= epsilon*obj & sum(s^2) != 0);
        iter = iter + 1
    }

    w[,iter_class] = w_class
} # parfor loop
# Append two metadata rows beneath the weight rows before writing the model:
# the first extra row carries the intercept flag in column 1, the second the
# original feature count in column 1; all remaining cells of both rows are 0.
# The metadata is built already transposed (one row per class, two columns)
# so it can be attached with cbind on the transposed weights.
model_meta_t = matrix(0, rows=ncol(w), cols=2)
model_meta_t[1, 1] = intercept
model_meta_t[1, 2] = dimensions
w = t(cbind(t(w), model_meta_t))
write(w, $model, format=cmdLine_fmt)
# Optionally write the per-iteration objective values as "class,iter,obj"
# lines. Log is an optional argument, so it is resolved through ifdef();
# a bare $Log reference makes the script fail when the argument is omitted.
# The single-space string is the "no log requested" sentinel.
logFile = ifdef($Log, " ")
if(logFile != " ") {
    debug_str = "# Class, Iter, Obj"
    for(iter_class in 1:ncol(debug_mat)){
        for(iter in 1:nrow(debug_mat)){
            obj = as.scalar(debug_mat[iter, iter_class])
            # -1 marks iterations that were never executed for this class
            if(obj != -1) {
                debug_str = append(debug_str, iter_class + "," + iter + "," + obj)
            }
        }
    }
    write(debug_str, logFile)
}