org.apache.sysml.runtime.transform.RecodeAgent Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.transform;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
import scala.Tuple2;
import com.google.common.collect.Ordering;
import org.apache.sysml.runtime.transform.MVImputeAgent.MVMethod;
import org.apache.sysml.runtime.util.UtilFunctions;
public class RecodeAgent extends TransformationAgent {
private static final long serialVersionUID = 8213163881283341874L;
private int[] _rcdList = null;
private int[] _mvrcdList = null;
private int[] _fullrcdList = null;
// HashMap< columnID, HashMap >
private HashMap> _rcdMaps = new HashMap>();
/**
 * Builds the agent from a parsed transformation specification.
 *
 * Reads the attribute lists stored under the RECODE and MVRCD methods (each is
 * optional) and concatenates them, RECODE first, into the combined column list.
 * The combined list is only allocated when it would contain at least one column.
 *
 * @param parsedSpec parsed JSON transformation specification
 * @throws JSONException if the specification can not be read
 */
RecodeAgent(JSONObject parsedSpec) throws JSONException {
	final String recodeKey = TX_METHOD.RECODE.toString();
	final String mvRecodeKey = TX_METHOD.MVRCD.toString();

	if (parsedSpec.containsKey(recodeKey)) {
		JSONObject recodeSpec = (JSONObject) parsedSpec.get(recodeKey);
		JSONArray recodeAttrs = (JSONArray) recodeSpec.get(JSON_ATTRS);
		_rcdList = new int[recodeAttrs.size()];
		for (int pos = 0; pos < _rcdList.length; pos++) {
			_rcdList[pos] = UtilFunctions.toInt(recodeAttrs.get(pos));
		}
	}

	if (parsedSpec.containsKey(mvRecodeKey)) {
		JSONObject mvRecodeSpec = (JSONObject) parsedSpec.get(mvRecodeKey);
		JSONArray mvRecodeAttrs = (JSONArray) mvRecodeSpec.get(JSON_ATTRS);
		_mvrcdList = new int[mvRecodeAttrs.size()];
		for (int pos = 0; pos < _mvrcdList.length; pos++) {
			_mvrcdList[pos] = UtilFunctions.toInt(mvRecodeAttrs.get(pos));
		}
	}

	final int numRecode = (_rcdList == null) ? 0 : _rcdList.length;
	final int numMvRecode = (_mvrcdList == null) ? 0 : _mvrcdList.length;
	final int total = numRecode + numMvRecode;
	if (total == 0) {
		return;
	}

	// Combined list: recoded columns first, then the MVRCD columns.
	_fullrcdList = new int[total];
	if (_rcdList != null) {
		System.arraycopy(_rcdList, 0, _fullrcdList, 0, numRecode);
	}
	if (_mvrcdList != null) {
		System.arraycopy(_mvrcdList, 0, _fullrcdList, numRecode, numMvRecode);
	}
}
/**
 * Updates the per-column value counts with the values of one input row.
 *
 * For every column in the combined recode list, the cell is trimmed and
 * unquoted, and its occurrence count in that column's map is incremented.
 *
 * @param words  cells of the current row; column IDs are 1-based indexes into it
 * @param agents not used by this method
 */
void prepare(String[] words, TfUtils agents) {
	// _fullrcdList is only allocated when at least one column is listed, so this
	// guard also covers a spec whose attribute lists are present but empty.
	if ( _fullrcdList == null )
		return;

	for (int colID : _fullrcdList) {
		String w = UtilFunctions.unquote(words[colID-1].trim());

		HashMap<String, Long> map = _rcdMaps.get(colID);
		if ( map == null ) {
			map = new HashMap<String, Long>();
			_rcdMaps.put(colID, map);
		}

		Long count = map.get(w);
		map.put(w, (count == null) ? 1L : count + 1);
	}
}
/**
 * Accounts for a constant imputation value in a column's value counts.
 *
 * When the column is imputed with MVMethod.CONSTANT, the replacement value is
 * added to the map (or its existing count is increased) by
 * {@code agents.getValid() - mvagent.getNonMVCount(colID)}, presumably the
 * number of missing cells in the column -- confirm against MVImputeAgent.
 * For any other method the map is returned untouched.
 *
 * @param colID  1-based column ID
 * @param agents provides the imputation agent and the valid-row count
 * @param map    value counts of the column; modified in place
 * @return the same map instance that was passed in
 * @throws RuntimeException if the column uses constant imputation but no replacement value is available
 */
private HashMap<String, Long> handleMVConstant(int colID, TfUtils agents, HashMap<String, Long> map)
{
	MVImputeAgent mvagent = agents.getMVImputeAgent();
	if ( mvagent.getMethod(colID) == MVMethod.CONSTANT )
	{
		// check if the "replacement" is part of the map. If not, add it.
		String repValue = mvagent.getReplacement(colID);
		if(repValue == null)
			throw new RuntimeException("Expecting a constant replacement value for column ID " + colID);

		repValue = UtilFunctions.unquote(repValue);
		Long count = map.get(repValue);
		long mvCount = agents.getValid() - mvagent.getNonMVCount(colID);
		if(count == null)
			map.put(repValue, mvCount);
		else
			map.put(repValue, count + mvCount);
	}
	return map;
}
/**
 * Method to output transformation metadata from the mappers.
 * This information is collected and merged by the reducers.
 *
 * Delegates to mapOutputHelper with a null list, so every
 * (column ID, distinct value) pair is handed to the collector.
 *
 * @param out    collector that receives the (column ID, DistinctValue) pairs
 * @param taskID task identifier, forwarded to the helper
 * @param agents provides the imputation agent used for constant imputation
 * @throws IOException if emitting the metadata fails
 */
@Override
public void mapOutputTransformationMetadata(OutputCollector<IntWritable, DistinctValue> out, int taskID, TfUtils agents) throws IOException {
	mapOutputHelper(taskID, out, null, agents);
}
public ArrayList> mapOutputTransformationMetadata(int taskID, ArrayList> list, TfUtils agents) throws IOException {
mapOutputHelper(taskID, null, list, agents);
return list;
}
public void mapOutputHelper(int taskID, OutputCollector out, ArrayList> list, TfUtils agents) throws IOException {
if ( _rcdList == null && _mvrcdList == null )
return;
try
{
for(int i=0; i < _fullrcdList.length; i++)
{
int colID = _fullrcdList[i];
HashMap map = _rcdMaps.get(colID);
if(map != null)
{
map = handleMVConstant(colID, agents, map);
if ( out != null ) {
IntWritable iw = new IntWritable(colID);
for(String s : map.keySet())
out.collect(iw, new DistinctValue(s, map.get(s)));
}
else if ( list != null ) {
for(String s : map.keySet())
list.add(new Tuple2(colID, new DistinctValue(s, map.get(s))) );
}
}
}
} catch(Exception e) {
throw new IOException(e);
}
}
/**
 * Function to output transformation metadata, including:
 * - recode maps,
 * - number of distinct values,
 * - mode, and
 * - imputation value (in the case of global_mode)
 *
 * The column for which this function is invoked can be one of the following:
 * - just recoded (write .map, .ndistinct, .mode)
 * - just mv imputed (w/ global_mode) (write .impute)
 * - both recoded and mv imputed (write .map, .ndistinct, .mode, .impute)
 *
 * Side effect: the counts stored in {@code map} are replaced by the recode
 * indexes (1-based position of each value in natural string order).
 *
 * @param map       distinct values of the column and their counts; modified in place
 * @param outputDir root directory under which the Recode/ and Impute/ files are written
 * @param colID     1-based column ID
 * @param fs        file system used to create the output files
 * @param agents    provides column names, NA strings and the imputation agent
 * @param fromCP    if true, the constant imputation value is merged into the map here
 * @throws IOException if a metadata file can not be written
 * @throws RuntimeException if no value remains after removing the NA strings
 */
private void writeMetadata(HashMap<String, Long> map, String outputDir, int colID, FileSystem fs, TfUtils agents, boolean fromCP) throws IOException {
	// output recode maps and mode
	MVImputeAgent mvagent = agents.getMVImputeAgent();
	String mode = null;
	int rcdIndex = 0, modeIndex = 0;
	long maxCount = Long.MIN_VALUE;

	boolean isRecoded = (isRecoded(colID) != -1);
	boolean isModeImputed = (mvagent.getMethod(colID) == MVMethod.GLOBAL_MODE);

	Path pt = new Path(outputDir+"/Recode/"+ agents.getName(colID) + RCD_MAP_FILE_SUFFIX);
	BufferedWriter br = null;
	try {
		if(isRecoded)
			br = new BufferedWriter(new OutputStreamWriter(fs.create(pt,true)));

		// remove NA strings
		if ( agents.getNAStrings() != null)
			for(String naword : agents.getNAStrings())
				map.remove(naword);

		if(fromCP)
			map = handleMVConstant(colID, agents, map);

		if ( map.size() == 0 )
			throw new RuntimeException("Can not proceed since \"" + agents.getName(colID) + "\" (id=" + colID + ") contains only the missing values, and not a single valid value -- set imputation method to \"constant\".");

		// Order entries by category (string) value
		Ordering<String> valueComparator = Ordering.natural();
		List<String> newNames = valueComparator.sortedCopy(map.keySet());

		for(String w : newNames) {
			long count = map.get(w);
			++rcdIndex;

			// output (w, count, rcdIndex)
			if(br != null)
				br.write(UtilFunctions.quote(w) + TXMTD_SEP + rcdIndex + TXMTD_SEP + count + "\n");

			if(maxCount < count) {
				maxCount = count;
				mode = w;
				modeIndex = rcdIndex;
			}

			// Replace count with recode index (useful when invoked from CP)
			map.put(w, (long)rcdIndex);
		}
	}
	finally {
		// Close the map file even when one of the steps above fails.
		if(br != null)
			br.close();
	}

	if ( mode == null ) {
		mode = "";
		maxCount = 0;
	}

	if ( isRecoded )
	{
		// output mode
		pt = new Path(outputDir+"/Recode/"+ agents.getName(colID) + MODE_FILE_SUFFIX);
		writeRecodeMetaFile(fs, pt, UtilFunctions.quote(mode) + "," + modeIndex + "," + maxCount);

		// output number of distinct values
		pt = new Path(outputDir+"/Recode/"+ agents.getName(colID) + NDISTINCT_FILE_SUFFIX);
		writeRecodeMetaFile(fs, pt, ""+map.size());
	}

	if (isModeImputed)
	{
		pt = new Path(outputDir+"/Impute/"+ agents.getName(colID) + MV_FILE_SUFFIX);
		writeRecodeMetaFile(fs, pt, colID + "," + UtilFunctions.quote(mode));
	}
}

// Writes content to path (overwriting an existing file) and always closes the stream.
private static void writeRecodeMetaFile(FileSystem fs, Path path, String content) throws IOException {
	BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(fs.create(path, true)));
	try {
		bw.write(content);
	}
	finally {
		bw.close();
	}
}
/**
 * Writes the metadata files of every column in the combined recode list, using
 * the maps collected by prepare(). writeMetadata is invoked with fromCP=true.
 *
 * @param outputDir root directory for the metadata files
 * @param fs        file system used to create the files
 * @param agents    provides column names, NA strings and the imputation agent
 * @throws IOException if a metadata file can not be written
 */
public void outputTransformationMetadata(String outputDir, FileSystem fs, TfUtils agents) throws IOException {
	// _fullrcdList is only allocated when at least one column is listed, so this
	// guard also covers a spec whose attribute lists are present but empty.
	if ( _fullrcdList == null )
		return;

	for (int colID : _fullrcdList) {
		writeMetadata(_rcdMaps.get(colID), outputDir, colID, fs, agents, true);
	}
}
/**
 * Method to merge map output transformation metadata.
 *
 * Sums the counts of equal words over all incoming values and writes the
 * merged map for the column through writeMetadata (fromCP=false).
 *
 * @param values    partial (word, count) pairs produced for this column
 * @param outputDir root directory for the metadata files
 * @param colID     1-based column ID the values belong to
 * @param fs        file system used to create the files
 * @param agents    provides column names, NA strings and the imputation agent
 * @throws IOException if a metadata file can not be written
 */
@Override
public void mergeAndOutputTransformationMetadata(Iterator<DistinctValue> values, String outputDir, int colID, FileSystem fs, TfUtils agents) throws IOException {
	HashMap<String, Long> map = new HashMap<String, Long>();

	DistinctValue d = new DistinctValue();
	String word = null;
	Long count = null, val = null;
	while(values.hasNext()) {
		// NOTE(review): reset() is applied to the previous object right before it is
		// replaced; kept as-is in case the iterator hands back a reused instance.
		d.reset();
		d = values.next();

		word = d.getWord();
		count = d.getCount();

		val = map.get(word);
		if(val == null)
			map.put(word, count);
		else
			map.put(word, val+count);
	}

	writeMetadata(map, outputDir, colID, fs, agents, false);
}
// ------------------------------------------------------------------------------------------------
public HashMap> getCPRecodeMaps() { return _rcdMaps; }
HashMap> _finalMaps = null;
public HashMap> getRecodeMaps() {
return _finalMaps;
}
/**
* Method to load recode maps of all attributes, at once.
*
* @param job
* @throws IOException
*/
@Override
public void loadTxMtd(JobConf job, FileSystem fs, Path txMtdDir, TfUtils agents) throws IOException {
if ( _rcdList == null )
return;
_finalMaps = new HashMap>();
if(fs.isDirectory(txMtdDir)) {
for(int i=0; i<_rcdList.length;i++) {
int colID = _rcdList[i];
Path path = new Path( txMtdDir + "/Recode/" + agents.getName(colID) + RCD_MAP_FILE_SUFFIX);
TfUtils.checkValidInputFile(fs, path, true);
HashMap map = new HashMap();
BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)));
String line = null, word=null;
String rcdIndex = null;
// Example line to parse: "WN (1)67492",1,61975
while((line=br.readLine())!=null) {
// last occurrence of quotation mark
int idxQuote = line.lastIndexOf('"');
word = UtilFunctions.unquote(line.substring(0,idxQuote+1));
int idx = idxQuote+2;
while(line.charAt(idx) != TXMTD_SEP.charAt(0))
idx++;
rcdIndex = line.substring(idxQuote+2,idx);
map.put(word, rcdIndex);
}
br.close();
_finalMaps.put(colID, map);
}
}
else {
fs.close();
throw new RuntimeException("Path to recode maps must be a directory: " + txMtdDir);
}
}
/**
 * Method to apply transformations.
 *
 * Replaces the cell of every column in the RECODE list with the recode index
 * found in the maps read by loadTxMtd(). A value that has no entry in its
 * column's map is replaced by null.
 *
 * @param words  cells of the current row; column IDs are 1-based indexes into it
 * @param agents not used by this method
 * @return the same array, modified in place
 * @throws RuntimeException if the recode map of a column is not loaded or the cell is null
 */
@Override
public String[] apply(String[] words, TfUtils agents) {
	if ( _rcdList == null )
		return words;

	for (int colID : _rcdList) {
		// Checked explicitly instead of catching NullPointerException: the former
		// handler dereferenced _finalMaps again and failed when it was null.
		HashMap<String, String> map = (_finalMaps == null) ? null : _finalMaps.get(colID);
		String raw = words[colID-1];
		if ( map == null || raw == null ) {
			System.err.println("Maps for colID="+colID + " may be null (map = " + map + ")");
			throw new RuntimeException("Can not recode column ID " + colID + ": "
					+ (map == null ? "recode map is not loaded" : "input value is null"));
		}
		words[colID-1] = map.get(UtilFunctions.unquote(raw.trim()));
	}

	return words;
}
/**
 * Looks up the given column ID in the RECODE list.
 *
 * The lookup uses Arrays.binarySearch, which requires the list to be sorted in
 * ascending order. NOTE(review): the constructor stores the IDs in the order
 * given by the spec -- confirm that the spec always lists them sorted.
 *
 * @param colID 1-based column ID
 * @return position of colID in the RECODE list, or -1 if the list is absent
 *         or does not contain the ID
 */
public int isRecoded(int colID)
{
	if (_rcdList != null) {
		int pos = Arrays.binarySearch(_rcdList, colID);
		if (pos >= 0) {
			return pos;
		}
	}
	return -1;
}
/**
 * Applies recoding using the maps built in this process (see getCPRecodeMaps).
 *
 * Replaces the cell of every column in the RECODE list with the textual form
 * of the long stored for that value in the column's map.
 *
 * @param words  cells of the current row; column IDs are 1-based indexes into it
 * @param agents used to check whether the empty string is an NA string
 * @return the same array, modified in place
 * @throws RuntimeException if a value has no entry in its column's map
 */
public String[] cp_apply(String[] words, TfUtils agents) {
	if ( _rcdList == null )
		return words;

	for (int colID : _rcdList) {
		String raw = words[colID-1];
		String w = (raw == null) ? null : UtilFunctions.unquote(raw.trim());

		// Looked up explicitly instead of catching NullPointerException, whose
		// handler could itself fail on a null (or stale) word.
		HashMap<String, Long> map = _rcdMaps.get(colID);
		Long code = (map == null || w == null) ? null : map.get(w);

		if ( code == null ) {
			if( w != null && w.isEmpty() && agents.isNA("") )
				throw new RuntimeException("Empty string (a missing value) in column ID " + colID + " is not handled. Consider adding an imputation method on this column.");
			throw new RuntimeException("ColID="+colID + ", word=" + raw + ", maps entry not found (map = " + map + ")");
		}

		words[colID-1] = Long.toString(code);
	}

	return words;
}
/**
 * Prints, for every column that has a map, the column ID followed by one
 * "value : number" line per entry to standard output.
 */
public void printMaps() {
	for (Map.Entry<Integer, HashMap<String, Long>> column : _rcdMaps.entrySet()) {
		System.out.println("Column " + column.getKey());
		for (Map.Entry<String, Long> e : column.getValue().entrySet()) {
			System.out.println(" " + e.getKey() + " : " + e.getValue());
		}
	}
}
/**
 * Prints the column IDs of the RECODE list on one line to standard output.
 * Only the header is printed when the spec has no RECODE list.
 */
public void print() {
	System.out.print("Recoding List: \n ");
	// _rcdList stays null when the spec has no RECODE method.
	if ( _rcdList != null ) {
		for (int i : _rcdList) {
			System.out.print(i + " ");
		}
	}
	System.out.println();
}
}