All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.hops.codegen.opt.PlanSelection Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.runtime.util.UtilFunctions;

public abstract class PlanSelection 
{
	private final HashMap> _bestPlans = 
			new HashMap>();
	private final HashSet _visited = new HashSet();
	
	/**
	 * Given a HOP DAG G, and a set of partial fusions plans P, find the set of optimal, 
	 * non-conflicting fusion plans P' that applied to G minimizes costs C with
	 * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a set of 
	 * constraints such as memory budgets and block size restrictions per fused operator.
	 * 
	 * @param memo partial fusion plans P
	 * @param roots entry points of HOP DAG G
	 */
	public abstract void selectPlans(CPlanMemoTable memo, ArrayList roots);	
	
	/**
	 * Determines if the given partial fusion plan is valid.
	 * 
	 * @param me memo table entry
	 * @param hop current hop
	 * @return true if entry is valid as top-level plan
	 */
	public static boolean isValid(MemoTableEntry me, Hop hop) {
		return (me.type != TemplateType.OUTER //ROW, CELL, MAGG
			|| (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
	}
	
	protected void addBestPlan(long hopID, MemoTableEntry me) {
		if( me == null ) return;
		if( !_bestPlans.containsKey(hopID) )
			_bestPlans.put(hopID, new ArrayList());
		_bestPlans.get(hopID).add(me);
	}
	
	protected HashMap> getBestPlans() {
		return _bestPlans;
	}
	
	protected boolean isVisited(long hopID, TemplateType type) {
		return _visited.contains(new VisitMark(hopID, type));
	}
	
	protected void setVisited(long hopID, TemplateType type) {
		_visited.add(new VisitMark(hopID, type));
	}
	
	/**
	 * Basic plan comparator to compare memo table entries with regard to
	 * a pre-defined template preference order and the number of references.
	 */
	protected static class BasicPlanComparator implements Comparator {
		@Override
		public int compare(MemoTableEntry o1, MemoTableEntry o2) {
			return icompare(o1, o2);
		}
		
		public static int icompare(MemoTableEntry o1, MemoTableEntry o2) {
			if( o2 == null ) return -1;
			
			//for different types, select preferred type
			if( o1.type != o2.type )
				return Integer.compare(o1.type.getRank(), o2.type.getRank());
			
			//for same type, prefer plan with more refs
			return Integer.compare(-o1.countPlanRefs(), -o2.countPlanRefs());
		}
	}
	
	protected static class TypedPlanComparator implements Comparator {
		private TemplateType _type;
		
		public void setType(TemplateType type) {
			_type = type;
		}
		
		@Override
		public int compare(MemoTableEntry o1, MemoTableEntry o2) {
			return icompare(o1, o2, _type);
		}
		
		public static int icompare(MemoTableEntry o1, MemoTableEntry o2, TemplateType type) {
			if( o2 == null ) return -1;
			int score1 = 7 - ((o1.type==type)?4:0) - o1.countPlanRefs();
			int score2 = 7 - ((o2.type==type)?4:0) - o2.countPlanRefs();
			return Integer.compare(score1, score2);
		}
	}
	
	protected static class VisitMark {
		private final long _hopID;
		private final TemplateType _type;
		
		public VisitMark(long hopID, TemplateType type) {
			_hopID = hopID;
			_type = type;
		}
		@Override
		public int hashCode() {
			return UtilFunctions.longHashCode(
				_hopID, (_type!=null)?_type.hashCode():0);
		}
		@Override 
		public boolean equals(Object o) {
			return (o instanceof VisitMark
				&& _hopID == ((VisitMark)o)._hopID
				&& _type == ((VisitMark)o)._type);
		}
	}
	
	public static class VisitMarkCost {
		private final long _hopID;
		private final long _costID;
		
		public VisitMarkCost(long hopID, long costID) {
			_hopID = hopID;
			_costID = costID;
		}
		@Override
		public int hashCode() {
			return UtilFunctions.longHashCode(
				_hopID, _costID);
		}
		@Override 
		public boolean equals(Object o) {
			return (o instanceof VisitMarkCost
				&& _hopID == ((VisitMarkCost)o)._hopID
				&& _costID == ((VisitMarkCost)o)._costID);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy