/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.transform;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

import scala.Tuple2;

import com.google.common.collect.Ordering;
import org.apache.sysml.runtime.transform.MVImputeAgent.MVMethod;
import org.apache.sysml.runtime.util.UtilFunctions;

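/**
 * Transformation agent that builds and applies per-column recode maps, i.e.,
 * dictionaries from distinct string values to integer codes. It also writes the
 * related transformation metadata: the recode map, the number of distinct
 * values, the mode, and the imputation value for global_mode imputation.
 * (Descriptive summary added here; the behavior is read directly from the code below.)
 */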
public class RecodeAgent extends TransformationAgent {
	
	private static final long serialVersionUID = 8213163881283341874L;

	private int[] _rcdList = null;
	private int[] _mvrcdList = null;
	private int[] _fullrcdList = null;

	// HashMap< columnID, HashMap<distinctValue, count> >
	private HashMap<Integer, HashMap<String, Long>> _rcdMaps  = new HashMap<Integer, HashMap<String, Long>>();
	
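	// Illustrative sketch (not taken verbatim from the source) of the spec shape the
	// constructor consumes, assuming TX_METHOD.RECODE/MVRCD and JSON_ATTRS resolve
	// to the keys shown here:
	//   { "recode": { "attributes": [2, 5] }, "mvrcd": { "attributes": [7] } }
	// Attribute entries are 1-based column IDs (note the words[colID-1] accesses below);
	// the RECODE and MVRCD columns are concatenated into _fullrcdList.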
	RecodeAgent(JSONObject parsedSpec) throws JSONException {
		
		int rcdCount = 0;
		
		if ( parsedSpec.containsKey(TX_METHOD.RECODE.toString())) 
		{
			JSONObject obj = (JSONObject) parsedSpec.get(TX_METHOD.RECODE.toString());
			JSONArray attrs = (JSONArray) obj.get(JSON_ATTRS);
			
			_rcdList = new int[attrs.size()];
			for(int i=0; i < _rcdList.length; i++) 
				_rcdList[i] = UtilFunctions.toInt(attrs.get(i));
			rcdCount = _rcdList.length;
		}
		
		if ( parsedSpec.containsKey(TX_METHOD.MVRCD.toString())) 
		{
			JSONObject obj = (JSONObject) parsedSpec.get(TX_METHOD.MVRCD.toString());
			JSONArray attrs = (JSONArray) obj.get(JSON_ATTRS);
			
			_mvrcdList = new int[attrs.size()];
			for(int i=0; i < _mvrcdList.length; i++) 
				_mvrcdList[i] = UtilFunctions.toInt(attrs.get(i));
			rcdCount += attrs.size();
		}
		
		if ( rcdCount > 0 )
		{
			_fullrcdList = new int[rcdCount];
			int idx = -1;
			if(_rcdList != null)
				for(int i=0; i < _rcdList.length; i++)
					_fullrcdList[++idx] = _rcdList[i]; 
			
			if(_mvrcdList != null)
				for(int i=0; i < _mvrcdList.length; i++)
					_fullrcdList[++idx] = _mvrcdList[i]; 
		}
	}
	
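	// Builds per-column frequency counts over all recoded columns. For example
	// (illustrative only), feeding rows with values "CA", "NY", "CA" for a recoded
	// column yields the map {"CA"=2, "NY"=1} for that column ID.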
	void prepare(String[] words, TfUtils agents) {
		if ( _rcdList == null && _mvrcdList == null )
			return;
		
		String w = null;
		for (int colID : _fullrcdList) {
			w = UtilFunctions.unquote(words[colID-1].trim());
			if(_rcdMaps.get(colID) == null ) 
				_rcdMaps.put(colID, new HashMap<String, Long>());
			
			HashMap<String, Long> map = _rcdMaps.get(colID);
			Long count = map.get(w);
			if(count == null)
				map.put(w, new Long(1));
			else
				map.put(w, count+1);
		}
	}
	
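	// If the column is imputed with a user-provided constant, ensure that the
	// replacement value itself is present in the recode map: its count is set
	// (or increased) by the number of missing entries, so that it receives a
	// recode index like any other category.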
	private HashMap<String, Long> handleMVConstant(int colID, TfUtils agents, HashMap<String, Long> map)
	{
		MVImputeAgent mvagent = agents.getMVImputeAgent();
		if ( mvagent.getMethod(colID) == MVMethod.CONSTANT ) 
		{
			// check if the "replacement" is part of the map. If not, add it.
			String repValue = mvagent.getReplacement(colID);
			if(repValue == null)
				throw new RuntimeException("Expecting a constant replacement value for column ID " + colID);
			
			repValue = UtilFunctions.unquote(repValue);
			Long count = map.get(repValue);
			long mvCount = agents.getValid() - mvagent.getNonMVCount(colID);
			if(count == null)
				map.put(repValue, mvCount);
			else
				map.put(repValue, count + mvCount);
		}
		return map;
	}
	
	/**
	 * Method to output transformation metadata from the mappers. 
	 * This information is collected and merged by the reducers.
	 * 
	 * @param out output collector for (columnID, distinct value) pairs
	 * @param taskID map task identifier
	 * @param agents wrapper over the transformation agents
	 * @throws IOException if the metadata cannot be emitted
	 */
	@Override
	public void mapOutputTransformationMetadata(OutputCollector<IntWritable, DistinctValue> out, int taskID, TfUtils agents) throws IOException {
		mapOutputHelper(taskID, out, null, agents);
	}
	
	public ArrayList<Tuple2<Integer, DistinctValue>> mapOutputTransformationMetadata(int taskID, ArrayList<Tuple2<Integer, DistinctValue>> list, TfUtils agents) throws IOException {
		mapOutputHelper(taskID, null, list, agents);
		return list;
	}
	
	public void mapOutputHelper(int taskID, OutputCollector<IntWritable, DistinctValue> out, ArrayList<Tuple2<Integer, DistinctValue>> list, TfUtils agents) throws IOException {
		if ( _rcdList == null  && _mvrcdList == null )
			return;
		
		try 
		{ 
			for(int i=0; i < _fullrcdList.length; i++) 
			{
				int colID = _fullrcdList[i];
				HashMap<String, Long> map = _rcdMaps.get(colID);
				
				if(map != null) 
				{
					map = handleMVConstant(colID, agents,  map);
					
					if ( out != null ) {
						IntWritable iw = new IntWritable(colID);
						for(String s : map.keySet()) 
							out.collect(iw, new DistinctValue(s, map.get(s)));
					}
					else if ( list != null ) {
						for(String s : map.keySet()) 
							list.add(new Tuple2<Integer, DistinctValue>(colID, new DistinctValue(s, map.get(s))) );
					}
				}
			}
		} catch(Exception e) {
			throw new IOException(e);
		}
	}
	
	/**
	 * Function to output transformation metadata, including: 
	 * - recode maps, 
	 * - number of distinct values, 
	 * - mode, and 
	 * - imputation value (in the case of global_mode)
	 * 
	 * The column for which this function is invoked can be one of the following:
	 * - just recoded						(write .map, .ndistinct, .mode)
	 * - just mv imputed (w/ global_mode)	(write .impute)
	 * - both recoded and mv imputed		(write .map, .ndistinct, .mode, .impute)
	 * 
	 * @param map distinct-value counts for the column
	 * @param outputDir output directory for the metadata files
	 * @param colID column identifier
	 * @param fs file system on which the metadata is written
	 * @param agents wrapper over the transformation agents
	 * @param fromCP true if invoked from the single-node (CP) code path
	 * @throws IOException if the metadata cannot be written
	 */
	private void writeMetadata(HashMap<String, Long> map, String outputDir, int colID, FileSystem fs, TfUtils agents, boolean fromCP) throws IOException {
		// output recode maps and mode
		
		MVImputeAgent mvagent = agents.getMVImputeAgent();
		String mode = null;
		Long count = null;
		int rcdIndex = 0, modeIndex = 0;
		long maxCount = Long.MIN_VALUE;
		
		boolean isRecoded = (isRecoded(colID) != -1);
		boolean isModeImputed = (mvagent.getMethod(colID) == MVMethod.GLOBAL_MODE);
		
		Path pt=new Path(outputDir+"/Recode/"+ agents.getName(colID) + RCD_MAP_FILE_SUFFIX);
		BufferedWriter br=null;
		if(isRecoded)
			br = new BufferedWriter(new OutputStreamWriter(fs.create(pt,true)));		

		// remove NA strings
		if ( agents.getNAStrings() != null)
			for(String naword : agents.getNAStrings()) 
				map.remove(naword);
		
		if(fromCP)
			map = handleMVConstant(colID, agents,  map);
		
		if ( map.size() == 0 ) 
			throw new RuntimeException("Cannot proceed since \"" + agents.getName(colID) + "\" (id=" + colID + ") contains only missing values and not a single valid value -- set imputation method to \"constant\".");
		
		// Order entries by category (string) value
		Ordering<String> valueComparator = Ordering.natural();
		List<String> newNames = valueComparator.sortedCopy(map.keySet());

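		// Each entry is written as one line of the .map file in the form
		// "token"<TXMTD_SEP>recodeIndex<TXMTD_SEP>count, e.g. (illustrative,
		// assuming TXMTD_SEP is ","):  "CA",1,120
		// While writing, the running mode (most frequent value) is tracked.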
		for(String w : newNames) { //map.keySet()) {
				count = map.get(w);
				++rcdIndex;
				
				// output (w, count, rcdIndex)
				if(br != null)		
					br.write(UtilFunctions.quote(w) + TXMTD_SEP + rcdIndex + TXMTD_SEP + count  + "\n");
				
				if(maxCount < count) {
					maxCount = count;
					mode = w;
					modeIndex = rcdIndex;
				}
				
				// Replace count with recode index (useful when invoked from CP)
				map.put(w, (long)rcdIndex);
		}
		
		if(br != null)		
			br.close();
		
		if ( mode == null ) {
			mode = "";
			maxCount = 0;
		}
		
		if ( isRecoded ) 
		{
			// output mode
			pt=new Path(outputDir+"/Recode/"+ agents.getName(colID) + MODE_FILE_SUFFIX);
			br=new BufferedWriter(new OutputStreamWriter(fs.create(pt,true)));
			br.write(UtilFunctions.quote(mode) + "," + modeIndex + "," + maxCount );
			br.close();
		
			// output number of distinct values
			pt=new Path(outputDir+"/Recode/"+ agents.getName(colID) + NDISTINCT_FILE_SUFFIX);
			br=new BufferedWriter(new OutputStreamWriter(fs.create(pt,true)));
			br.write(""+map.size());
			br.close();
		}
		
		if (isModeImputed) 
		{
			pt=new Path(outputDir+"/Impute/"+ agents.getName(colID) + MV_FILE_SUFFIX);
			br=new BufferedWriter(new OutputStreamWriter(fs.create(pt,true)));
			br.write(colID + "," + UtilFunctions.quote(mode));
			br.close();
		}
		
	}
	
	public void outputTransformationMetadata(String outputDir, FileSystem fs, TfUtils agents) throws IOException {
		if(_rcdList == null && _mvrcdList == null )
			return;
		
		for(int i=0; i<_fullrcdList.length; i++) {
			int colID = _fullrcdList[i];
			writeMetadata(_rcdMaps.get(colID), outputDir, colID, fs, agents, true);
		}
	}
	
	/** 
	 * Method to merge map output transformation metadata.
	 * 
	 * @param values distinct values (with partial counts) emitted by the mappers
	 * @param outputDir output directory for the metadata files
	 * @param colID column identifier
	 * @param fs file system on which the metadata is written
	 * @param agents wrapper over the transformation agents
	 * @throws IOException if the merged metadata cannot be written
	 */
	@Override
	public void mergeAndOutputTransformationMetadata(Iterator<DistinctValue> values, String outputDir, int colID, FileSystem fs, TfUtils agents) throws IOException {
		HashMap<String, Long> map = new HashMap<String, Long>();
		
		DistinctValue d = new DistinctValue();
		String word = null;
		Long count = null, val = null;
		while(values.hasNext()) {
			d.reset();
			d = values.next();
			
			word = d.getWord();
			count = d.getCount();
			
			val = map.get(word);
			if(val == null) 
				map.put(word, count);
			else 
				map.put(word, val+count);
		}
		
		writeMetadata(map, outputDir, colID, fs, agents, false);
	}
	
	// ------------------------------------------------------------------------------------------------
	
	public HashMap<Integer, HashMap<String, Long>> getCPRecodeMaps() { return _rcdMaps; }
	
	HashMap<Integer, HashMap<String, String>> _finalMaps = null;
	public HashMap<Integer, HashMap<String, String>> getRecodeMaps() {
		return _finalMaps;
	}
	
	/**
	 * Method to load recode maps of all attributes, at once.
	 * 
	 * @param job job configuration
	 * @param fs file system from which the metadata is read
	 * @param txMtdDir directory containing the transformation metadata
	 * @param agents wrapper over the transformation agents
	 * @throws IOException if the recode maps cannot be read
	 */
	@Override
	public void loadTxMtd(JobConf job, FileSystem fs, Path txMtdDir, TfUtils agents) throws IOException {
		if ( _rcdList == null )
			return;
		
		_finalMaps = new HashMap<Integer, HashMap<String, String>>();
	
		if(fs.isDirectory(txMtdDir)) {
			for(int i=0; i<_rcdList.length;i++) {
				int colID = _rcdList[i];
				
				Path path = new Path( txMtdDir + "/Recode/" + agents.getName(colID) + RCD_MAP_FILE_SUFFIX);
				TfUtils.checkValidInputFile(fs, path, true); 
				
				HashMap<String, String> map = new HashMap<String, String>();
				
				BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)));
				String line = null, word=null;
				String rcdIndex = null;
				
				// Example line to parse: "WN (1)67492",1,61975
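				// The quoted token may itself contain the separator character, so the
				// closing quote is located first and the recode index is read as the
				// substring up to the next separator.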
				while((line=br.readLine())!=null) {
					
					// last occurrence of quotation mark
					int idxQuote = line.lastIndexOf('"');
					word = UtilFunctions.unquote(line.substring(0,idxQuote+1));
					
					int idx = idxQuote+2;
					while(line.charAt(idx) != TXMTD_SEP.charAt(0))
						idx++;
					rcdIndex = line.substring(idxQuote+2,idx); 
					
					map.put(word, rcdIndex);
				}
				br.close();
				_finalMaps.put(colID, map);
			}
		}
		else {
			fs.close();
			throw new RuntimeException("Path to recode maps must be a directory: " + txMtdDir);
		}
	}
	
	/**
	 * Method to apply transformations.
	 * 
	 * @param words tokens of the current row, one per column
	 * @param agents wrapper over the transformation agents
	 * @return the input row with recoded columns replaced by their recode indexes
	 */
	@Override
	public String[] apply(String[] words, TfUtils agents) {
		if ( _rcdList == null )
			return words;
		
		for(int i=0; i < _rcdList.length; i++) {
			int colID = _rcdList[i];
			try {
				words[colID-1] = _finalMaps.get(colID).get(UtilFunctions.unquote(words[colID-1].trim()));
			} catch(NullPointerException e) {
				System.err.println("Maps for colID="+colID + " may be null (map = " + _finalMaps.get(colID) + ")");
				throw new RuntimeException(e);
			}
		}
			
		return words;
	}
	
	/**
	 * Check whether the given column ID is subject to this transformation.
	 * 
	 * @param colID column identifier
	 * @return position of colID in the recode list, or -1 if the column is not recoded
	 */
	public int isRecoded(int colID)
	{
		if(_rcdList == null)
			return -1;
		
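		// Note: binarySearch assumes _rcdList is sorted in ascending order, i.e.,
		// the recode attribute IDs are expected to be given in sorted order in the
		// transformation spec.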
		int idx = Arrays.binarySearch(_rcdList, colID);
		return ( idx >= 0 ? idx : -1);
	}

	public String[] cp_apply(String[] words, TfUtils agents) {
		if ( _rcdList == null )
			return words;
		
		String w = null;
		for(int i=0; i < _rcdList.length; i++) {
			int colID = _rcdList[i];
			try {
				w = UtilFunctions.unquote(words[colID-1].trim());
				words[colID-1] = Long.toString(_rcdMaps.get(colID).get(w));
			} catch(NullPointerException e) {
				if(w.isEmpty() && agents.isNA("") )
					throw new RuntimeException("Empty string (a missing value) in column ID " + colID + " is not handled. Consider adding an imputation method on this column.");		
				throw new RuntimeException("ColID="+colID + ", word=" + words[colID-1] + ", maps entry not found (map = " + _rcdMaps.get(colID) + ")");
			}
		}
			
		return words;
	}
	
	
	public void printMaps() {
		for(Integer k : _rcdMaps.keySet()) {
			System.out.println("Column " + k);
			HashMap<String, Long> map = _rcdMaps.get(k);
			for(String w : map.keySet()) {
				System.out.println("    " + w + " : " + map.get(w));
			}
		}
	}
	
	public void print() {
		System.out.print("Recoding List: \n    ");
		for(int i : _rcdList) {
			System.out.print(i + " ");
		}
		System.out.println();
	}
}