All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datastax.insight.core.action.ActionHandler Maven / Gradle / Ivy

The newest version!
package com.datastax.insight.core.action;

import com.datastax.insight.core.Consts;
import com.datastax.insight.core.conf.Component;
import com.datastax.insight.core.conf.Constants;
import com.datastax.insight.core.entity.Context;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.core.dag.*;
import com.datastax.util.lang.ReflectUtil;
import com.datastax.util.lang.StringUtil;
import com.google.common.base.Strings;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.log4j.LogManager;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.sql.Dataset;

import javax.validation.constraints.NotNull;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URLDecoder;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.apache.log4j.Logger;

public class ActionHandler {

    // Flat registry of the actions created while parsing a DAG; re-created by parseDAG.
    private static List allActions = null;
    // Settings of the current run, assigned by the invoke(DAG, ...) entry point.
    private static RunMode mode = RunMode.RUN;
    private static int hits = 0;
    private static String nodeId = null;
    private static String explorationFile = null;

    // Fully-qualified names of handler classes; action class names are compared against these
    // when deciding how to resolve an action's parameters.
    private final static String DatumLoader="com.datastax.insight.ml.spark.data.DatumLoader";
    private final static String DataSetWriter="com.datastax.insight.ml.spark.data.dataset.DataSetWriter";
    private final static String TransformationHandler="com.datastax.insight.ml.spark.data.dataset.TransformationHandler";
    private final static String PipelineHandler="com.datastax.insight.ml.spark.ml.pipeline.PipelineHandler";
    private final static String ParamGridBuilderWrapper="com.datastax.insight.ml.spark.ml.tuning.ParamGridBuilderWrapper";

    // Class logger; progress messages in this class are written to it and to System.out.
    private static final Logger logger = LogManager.getLogger(ActionHandler.class);

    /**
     * Runs a flow graph.
     * <p>The run settings are stored in the class-level static fields, the DAG is converted
     * into an action tree by {@code parseDAG}, and the tree is executed through the
     * breadth-first {@code invoke(List)}. A start banner and an end banner are written to the
     * logger and to standard output around the execution.
     *
     * @param dag the DAG to run
     * @param mode run mode
     * @param hits sample count
     * @param nodeId node id
     * @param explorationFile value stored in the static field of the same name
     * @return the value returned by the breadth-first execution
     */
    public static Object invoke(DAG dag, RunMode mode, int hits, String nodeId, String explorationFile) {
        final String startedBanner = "===datastax-Insight Action Invoker Started===";
        final String endedBanner = "===datastax-Insight Action Invoker Ended===";

        // Publish the run settings before the DAG is parsed.
        ActionHandler.mode = mode;
        ActionHandler.hits = hits;
        ActionHandler.nodeId = nodeId;
        ActionHandler.explorationFile = explorationFile;

        // The execution starts from a single-element level holding the root action.
        List<Action> roots = new ArrayList<>();
        roots.add(parseDAG(dag));

        logger.info(startedBanner);
        System.out.println(startedBanner);

        Object outcome = invoke(roots);

        logger.info(endedBanner);
        System.out.println(endedBanner);
        return outcome;
    }

    /**
     * Reflectively calls the method an action describes and stores the returned value on the
     * action.
     * <p>This is deliberately best-effort: an exception raised by the call is not propagated.
     * It is written to the class logger (together with the target class and method) and to
     * standard error, and {@code null} is returned.
     *
     * @param action the action whose class name, method name and parameter types identify the target
     * @param parameters argument values passed to the target method
     * @return the value returned by the target method, or {@code null} when the call failed
     */
    public static Object invoke(Action action,Object[] parameters){
        try {
            Object ret = ReflectUtil.invokeMethod(action.getClassName(),
                    action.getMethodName(),
                    action.getParamTypes(), parameters);
            action.setResult(ret);
            return ret;
        } catch (Exception e) {
            // Guard the description so that a null action cannot raise a second exception here.
            String target = action == null
                    ? "<null action>"
                    : action.getClassName() + "#" + action.getMethodName();
            logger.error("Failed to invoke " + target, e);
            e.printStackTrace();
        }

        return null;
    }

    /**
     * Depth-first execution: runs the given action (unless it is the start or end vertex) and
     * then, recursively, each of its next actions.
     * <p>NOTE(review): after a child subtree has run, the child's result is overwritten with
     * the value returned from that subtree - confirm this is intended.
     *
     * @param action the action to start from
     * @return the value produced by the last invocation performed, or {@code null} if none
     */
    public static Object invoke(Action action) {
        Object last = null;

        String name = action.getName();
        boolean isBoundary = name.equals(DAG.START_VERTEX) || name.equals(DAG.END_VERTEX);
        if (!isBoundary) {
            last = invoke(action, prepareParams(action, null));
        }

        for (Action child : action.getNextActions()) {
            last = invoke(child);
            child.setResult(last);
        }

        return last;
    }

    /**
     * Breadth-first execution of one level of the action graph.
     * <p>Each action in {@code actions} that has no result yet and is not the start or end
     * vertex is executed: its parameters are resolved, its method is invoked, {@code sample}
     * is called and the flow status is persisted as 0. A {@code null} return from a method
     * whose return type is not {@code void} counts as a failure. On any failure the flow
     * status is persisted as -1, the error is logged, the remaining actions of the level are
     * skipped and no successors are executed.
     * <p>When the run mode is not {@code RUN} and the static {@code nodeId} names the action
     * that just ran, successors are not executed (the rest of the current level still is);
     * in {@code RUN_EXPLORATION} mode {@code writeExploration} is called for that action.
     * <p>If nothing stopped the traversal, the next actions of every action in the list are
     * executed recursively.
     *
     * @param actions the actions of the current level
     * @return the value produced by the most recent invocation or recursive call; may be {@code null}
     */
    public static Object invoke(List<Action> actions){

        Object ret=null;
        boolean flag=true;

        for(Action action : actions) {

            // 2018-01-04 bug fix: graph traversal can reach the same node more than once;
            // a node that already has a result is not executed again.
            if(action.getResult() != null) {
                continue;
            }

            if (!action.getName().equals(DAG.START_VERTEX) && !action.getName().equals(DAG.END_VERTEX)) {

                String started = "===datastax-Insight Component Started==="+
                        action.getId()+":"+action.getType()+"===";
                logger.info(started);
                System.out.println(started);

                try {
                    Object[] parameters = prepareParams(action, null);
                    ret = invoke(action, parameters);
                    sample(action);
                    //update flow status
                    PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                            "updateFlowStatus",
                            new String[]{Long.class.getTypeName(), Long.class.getTypeName(), Integer.class.getTypeName()},
                            new Object[]{PersistService.getFlowId(), PersistService.getBatchId(), 0});

                    // A null result is only acceptable when the target method returns void.
                    boolean failed = false;
                    if (ret == null) {
                        Method method = ReflectUtil.findMethod(Class.forName(action.getClassName()), action.getMethodName(), action.getParamTypes());
                        failed = !method.getReturnType().isAssignableFrom(void.class);
                    }

                    String ended = "===datastax-Insight Component Ended==="+
                            action.getId()+(failed ? ":0" : ":1")+":"+action.getType()+"===";
                    logger.info(ended);
                    System.out.println(ended);

                    if (failed) {
                        throw new IllegalStateException("Component " + action.getId()
                                + " returned null from non-void method "
                                + action.getClassName() + "#" + action.getMethodName());
                    }

                    //// TODO: 2018/2/5   nodeId stop and write hdfs
                    // The id is parsed from the trimmed text (the guard above it also trims) and
                    // as a primitive, so the comparison is always by value.
                    if (!RunMode.RUN.equals(mode)
                            && (nodeId != null && !nodeId.trim().isEmpty())
                            && action.getId() == Long.parseLong(nodeId.trim())) {
                        //stop
                        flag = false;
                        //write hdfs
                        if (RunMode.RUN_EXPLORATION.equals(mode)) {
                            writeExploration(action);
                            //log
                        }
                    }
                } catch (Exception ex){
                    // Record the exception type, message and cause chain; the frame-by-frame
                    // output below carries none of them.
                    logger.error("Component " + action.getId() + ":" + action.getType() + " failed", ex);

                    //update the flow status if execution failed
                    PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                            "updateFlowStatus",
                            new String[]{Long.class.getTypeName(), Long.class.getTypeName(), Integer.class.getTypeName()},
                            new Object[]{PersistService.getFlowId(), PersistService.getBatchId(), -1});

                    for(StackTraceElement element : ex.getStackTrace()) {
                        logger.info(element.toString());
                        System.out.println(element.toString());
                    }
                    flag=false;
                    break;
                }
            }
        }

        if(flag) {
            for (Action action : actions) {
                List<Action> subActions = action.getNextActions();
                if (subActions != null && subActions.size() > 0) {
                    ret = invoke(subActions);
                }
            }
        }

        return ret;
    }

    /**
     * Prepares the argument values for an action. Arguments fall into three categories:
     * 1. the output of a preceding action
     * 2. a fixed, configured value
     * 3. the container-wide value (the two container kinds are pipeline and transformer)
     * <p>Actions handled by the pipeline "fit", transformation and param-grid handlers are
     * delegated to their dedicated preparation methods.
     * @param action the current action
     * @param containerData the container-wide value (components inside a container refer to it
     *                      as pipeline.data or transformer.data)
     * @return the argument list for the action, or null when the action has no parameter values
     */
    private static Object[] prepareParams(Action action, Object containerData){
        if(action.getClassName().equals(PipelineHandler) && action.getMethodName().equals("fit")){
            return preparePipelineParams(action);
        }

        if(action.getClassName().equals(TransformationHandler)) {
            return prepareTransformerParams(action);
        }

        if(action.getClassName().equals(ParamGridBuilderWrapper)) {
            return prepareParamGridParams(action);
        }

        Object[] parameters=action.getParamValues();
        if(parameters==null) return null;
        Object[] ps=new Object[parameters.length];
        int index=0;
        // NOTE(review): the loop header and the beginning of the loop body are missing from
        // this copy of the file (text lost between "i" and the next line); restore them from
        // the original source before building.
        for(int i=0;i
                        //throw new Exception("parameter is not correct!");
                    }
                } else {
                    Object result = getPreActionResult(action, matcher.group(1), -1);
                    String fieldName = matcher.group(2);

                    // Without a property name the whole result is the argument; otherwise
                    // the named property is read from the result.
                    if(Strings.isNullOrEmpty(fieldName)) {
                        ps[i]=result;
                    } else {
                        ps[i] = getPropertyValue(result, fieldName);
                    }
                }
            }else if(p instanceof String && p.toString().contains(Constants.OUTPUT_ATUO)){                  //the user does not need to maintain the index
                long paramActionId = action.getPreActionId(index);
                Object result=getPreActionResult(action,p.toString(),paramActionId);
                index++;
                ps[i]=result;
            }else if(p instanceof String && p.toString().contains(Constants.PIPELINE_DATA)){
                if(containerData != null) {
                    ps[i] = containerData;
                } else {
                    //look up the enclosing pipeline action
                    Action pipelineAction = getAction(action, "pipeline");

                    //pipeline.data is taken from the pipeline's parameters (the first one)
                    String dataParameter = pipelineAction.getParamValues()[0].toString();
                    //check whether the value has to be computed: a value in the ${output}_ format is resolved at run time
                    if (dataParameter.contains(Constants.OUTPUT)) {
                        ps[i] = getPreActionResult(pipelineAction, dataParameter, -1);
                    } else {
                        ps[i] = pipelineAction.getParamValues()[0];
                    }
                }
            }else if(p instanceof String && p.toString().contains(Constants.TRANSFORMER_DATA)) {

                if(containerData != null) {
                    ps[i] = containerData;
                } else {
                    //look up the enclosing transformer action
                    Action transformerAction= getAction(action, "transformer");
                    //transformer.data is taken from the transformer's parameters (the first one)
                    String dataParameter = transformerAction.getParamValues()[0].toString();
                    //check whether the value has to be computed: a value in the ${output}_ format is resolved at run time
                    if(dataParameter.contains(Constants.OUTPUT)) {
                        ps[i] = getPreActionResult(transformerAction, dataParameter, -1);
                    } else {
                        ps[i] = transformerAction.getParamValues()[0];
                    }
                }
            }
            //Is it a system variable? (2018-07-31, andershong)
            //-----The variables have to be maintained in two places of the Insight admin system:
            //-----1. after a user logs in, the user's variables are written into the context; 2. when a user changes their own configuration
            else if(p instanceof String && isSysProp(p.toString())) {
                // Substitute the matching property key inside the parameter with its value.
                String prop=getSystProp(p.toString());
                ps[i]=((String) p).replace(prop, Context.SYSTEM_PROPERTIES.get(prop).toString());
            }else {
                //2017-12-19 tab fix: a tab typed into a component parameter breaks JSON parsing unless it is converted,
                //so the value is encoded before the JSON is saved and decoded again at execution time
                if(p instanceof String) {
                    try {
                        ps[i] = URLDecoder.decode(p.toString(), "UTF-8");
                    } catch (UnsupportedEncodingException e) {
                        ps[i] = p;
                    }
                } else {
                    ps[i] = p;
                }
            }
        }
        return ps;
    }

    /**
     * Tells whether the parameter text contains the key of any system property.
     *
     * @param parameter the parameter text to inspect
     * @return true when at least one key of {@code Context.SYSTEM_PROPERTIES} occurs in it
     */
    private static boolean isSysProp(String parameter){
        return Context.SYSTEM_PROPERTIES.keySet().stream()
                .anyMatch(key -> parameter.contains(key));
    }

    /**
     * Finds the first system property key that occurs in the parameter text.
     *
     * @param parameter the parameter text to inspect
     * @return the first key of {@code Context.SYSTEM_PROPERTIES} (in key-set iteration order)
     *         contained in the text, or null when there is none
     */
    private static String getSystProp(String parameter){
        return Context.SYSTEM_PROPERTIES.keySet().stream()
                .filter(key -> parameter.contains(key))
                .findFirst()
                .orElse(null);
    }

    /**
     * Resolves the arguments of a pipeline "fit" action.
     * <p>Fitting a pipeline takes two arguments: the data and the array of pipeline stages.
     * The data comes from the action's first parameter value, which is either resolved from a
     * preceding action's output (forms such as ${output}._x, ${output}._x[y], ${output}._x.y,
     * ${output}.x and ${output}.[x] are supported) or used as configured. Each stage is
     * produced by invoking one component of the pipeline, with the data passed along as the
     * container-wide value.
     *
     * @param action the pipeline action
     * @return a two-element array: the data and the {@code PipelineStage[]}
     */
    private static Object[] preparePipelineParams(Action action) {
        Object[] configured = action.getParamValues();
        long[] memberIds = action.getActions();

        // First argument: the data.
        Object data = null;
        if (configured != null && configured.length > 0) {
            String expression = configured[0].toString();
            if (expression.startsWith(Constants.OUTPUT)) {
                data = getPreActionResult(action, expression, -1);
            } else if (expression.startsWith(Constants.OUTPUT_ATUO)) {
                data = getPreActionResult(action, expression, action.getPreActionId(0));
            } else {
                data = configured[0];
            }
        }

        // Second argument: the stages, one slot per component of the pipeline.
        PipelineStage[] stages = new PipelineStage[memberIds.length];
        Object[] fitArguments = new Object[] { data, stages };

        // The member order of a pipeline is not guaranteed, so the components are sorted by
        // level at run time.
        List<Action> members = orderByLevel(memberIds);
        int slot = 0;
        for (Action member : members) {
            try {
                stages[slot] = (PipelineStage) invoke(member, prepareParams(member, data));
            } catch (Exception e) {
                // A stage that cannot be built leaves its slot null.
                e.printStackTrace();
            }
            slot++;
        }
        return fitArguments;
    }

    /**
     * Resolves the single argument of a transformer action.
     * <p>The starting data comes from the action's first parameter value: a value containing
     * the ${output}_ marker is resolved from a preceding action at run time, anything else is
     * used as configured. The components of the transformer are then invoked one after
     * another in member order, each receiving the previous component's return value as the
     * container-wide value.
     *
     * @param action the transformer action
     * @return a one-element array holding the value returned by the last component
     */
    private static Object[] prepareTransformerParams(Action action) {
        Object[] configured = action.getParamValues();
        long[] memberIds = action.getActions();

        Object data = null;
        if (configured != null && configured.length > 0) {
            String expression = configured[0].toString();
            data = expression.contains(Constants.OUTPUT)
                    ? getPreActionResult(action, expression, -1)
                    : configured[0];
        }

        // Chain the components: the output of one is the input of the next.
        for (long memberId : memberIds) {
            Action member = getActionFromList(memberId);
            data = invoke(member, prepareParams(member, data));
        }
        return new Object[] { data };
    }

    /**
     * Resolves the arguments of a param-grid builder action.
     * <p>The first parameter value is a list of grid entries separated by "::". Each entry is
     * {@code key} or {@code key:value}, where {@code key} is an action id, one separator
     * character and a member name (letters, digits, underscore). For every entry the output
     * of the referenced action is obtained, the named field (or, failing that, the no-argument
     * method of that name) is read from it, and the value read becomes a key of the resulting
     * map, mapped to the entry's value text (null when the entry has none).
     *
     * @param action the param-grid action
     * @return a two-element array: null and the map described above
     * @throws IllegalArgumentException when the grid definition is missing or an entry's key is malformed
     * @throws IllegalStateException when the referenced action produced no output
     */
    private static Object[] prepareParamGridParams(Action action) {
        Object[] parameters = action.getParamValues();
        if (parameters == null || parameters.length == 0 || parameters[0] == null) {
            throw new IllegalArgumentException(
                    "ParamGrid action " + action.getId() + " has no grid definition parameter");
        }

        String[] grids = parameters[0].toString().split("::");
        Map<Object, Object> realParams = new HashMap<>();
        // NOTE(review): the separator '.' is unescaped and therefore matches any character;
        // kept unchanged so that no previously accepted key format is rejected.
        Pattern pattern = Pattern.compile("([\\d]+).([A-Za-z0-9_]+)");

        for (String grid : grids) {
            String[] kv = grid.split(":");
            if (kv.length == 0) {
                throw new IllegalArgumentException("parameter is not correct: '" + grid + "'");
            }

            // The match must be attempted before any group is read: Matcher.group throws
            // IllegalStateException otherwise.
            Matcher matcher = pattern.matcher(kv[0]);
            if (!matcher.matches()) {
                throw new IllegalArgumentException("parameter is not correct: '" + grid + "'");
            }

            Object realKey = getAutoResult(action, Long.parseLong(matcher.group(1)));
            if (realKey == null) {
                throw new IllegalStateException(
                        "Action " + matcher.group(1) + " referenced by '" + grid + "' produced no output");
            }
            // The pattern guarantees a non-empty member name.
            String fieldName = matcher.group(2);
            Object gridValue = kv.length == 2 ? kv[1] : null;

            Field field = ReflectUtil.findField(realKey.getClass(), fieldName);
            if (field != null) {
                field.setAccessible(true);
                realParams.put(ReflectUtil.getField(field, realKey), gridValue);
                continue;
            }

            Method method = ReflectUtil.findMethod(realKey.getClass(), fieldName);
            if (method != null) {
                realParams.put(ReflectUtil.invokeMethod(method, realKey), gridValue);
            } else {
                // The entry is skipped, but no longer silently.
                logger.warn("ParamGrid entry '" + grid + "' ignored: no field or method named '"
                        + fieldName + "' on " + realKey.getClass().getName());
            }
        }

        return new Object[] { null, realParams };
    }

    /**
     * Sorts the components of a pipeline by level.
     * @param actionIds ids of the components (in the pipeline's member order)
     * @return the corresponding actions, ordered by ascending level
     */
    private static List orderByLevel(long[] actionIds) {
        List<Action> ordered = new ArrayList<>(actionIds.length);
        for (long actionId : actionIds) {
            ordered.add(getActionFromList(actionId));
        }
        // List.sort is stable, so components on the same level keep their member order.
        ordered.sort(Comparator.comparing(Action::getLevel));
        return ordered;
    }

    /**
     * Resolves an output expression to a value taken from a preceding action.
     * <p>Three shapes are handled:
     * <ul>
     * <li>an expression containing "[" and "]": the referenced output is obtained and the
     *     1-based element named between the brackets is returned when the output is a list or
     *     an array (of any component type);</li>
     * <li>an expression starting with {@code Constants.OUTPUT}: the output of the preceding
     *     action at the 1-based position given in the expression;</li>
     * <li>an expression starting with {@code Constants.OUTPUT_ATUO}: the output of the action
     *     {@code paramActionId}, optionally followed by ".name" to read the property
     *     {@code name} from it.</li>
     * </ul>
     * TODO (16-9-22): support richer expressions.
     *
     * @param action the current action
     * @param param the expression to resolve
     * @param paramActionId id of the preceding action for the automatic form; ignored by the
     *                      position-based form
     * @return the resolved value, or null when nothing could be resolved
     */
    private static Object getPreActionResult(Action action,String param, long paramActionId){
        if(param.contains("[") && param.contains("]")){
            Object result;
            if(param.startsWith(Constants.OUTPUT)) {
                int index = Integer.parseInt(StringUtil.substringIndent(param, Constants.OUTPUT, "["));
                result = action.getPreActions().get(index - 1).getResult();

                //TODO: 2017/11/24 a form like ${output}._1.labels is not handled yet
                //TODO: 2017/11/25 "_xx" and ".xx" accessors may come in any order; not handled either
                if(result == null) {
                    // No result yet: resolve the expression without its "[...]" suffix.
                    result = getPreActionResult(action, param.substring(0, param.indexOf("[")));
                }
            } else if (param.startsWith(Constants.OUTPUT_ATUO)) {
                //support for ${output}[x]
                result = getAutoResult(action, paramActionId);
            } else {
                result = getActionFromList(paramActionId).getResult();
            }

            if(result==null) return null;

            //collection indexing (1-based)
            int subIndex= Integer.parseInt(StringUtil.substringIndent(param,"[","]"));
            if(result instanceof List){
                return ((List<?>)result).get(subIndex-1);
            }else if(result.getClass().isArray()){
                // Reflective access also works for primitive arrays, which cannot be cast to Object[].
                return java.lang.reflect.Array.get(result, subIndex-1);
            }
        } else if (param.startsWith(Constants.OUTPUT)) {
            //With breadth-first execution a dependency of the current node may not have run
            //yet, so it has to be executed on demand to obtain the parameter.
            //                 A
            //                / \
            //               /   C
            //              B    |
            //               \   D
            //                \ /
            //                 E
            //The execution order is A->B->C->E; when E runs, D must be executed first.
            //If D's own dependencies have not run either, the backtracking continues upwards.
            return getPreActionResult(action, param);
        } else if (param.startsWith(Constants.OUTPUT_ATUO)) {
            //supports ${output} and ${output}.xxx
            Object result = getAutoResult(action, paramActionId);

            int dot = param.indexOf('.');
            if(dot >= 0) {
                // The property name starts after the dot; including the dot would never
                // match a field or method name.
                return getPropertyValue(result, param.substring(dot + 1));
            }
            return result;
        }
        return null;
    }

    /**
     * Obtains the output of a preceding node automatically from the DAG (the ${output} form).
     * <p>The preceding action is executed on demand when it has no result yet. When its result
     * is a list or an array (of any component type), the element at the output index declared
     * by the first parameter of the connecting edge is returned (index 0 when the edge or its
     * parameters are absent). The result of a "ParamGrid" action is returned as a whole.
     * @param action the current node
     * @param preActionId id of the preceding node
     * @return the preceding node's output, or null when it produced none
     * @throws IllegalStateException when no action with the given id is registered
     */
    private static Object getAutoResult(Action action, long preActionId) {
        Action preAction = getActionFromList(preActionId);
        if (preAction == null) {
            throw new IllegalStateException("No action with id " + preActionId
                    + " is registered as a predecessor of action " + action.getId());
        }

        Optional<Edge> edge = action.getDag().findEdge(preAction.getId(), action.getId());
        int outputIndex = 0;

        if(edge.isPresent()) {
            Edge realEdge = edge.get();

            if(realEdge.getParameters() != null && realEdge.getParameters().size() > 0) {
                outputIndex = realEdge.getParameters().get(0).getOutput();
            }
        }

        Object result = preAction.getResult();
        if (result == null) {
            result = execAction(preAction);
        }

        // Nothing to index into when the predecessor produced no value.
        if (result == null) {
            return null;
        }

        // Constant-first comparison: an action's type may be null.
        if ("ParamGrid".equals(preAction.getType())) {
            return result;
        }
        if (result instanceof List) {
            return ((List<?>) result).get(outputIndex);
        }
        if (result.getClass().isArray()) {
            // Reflective access also works for primitive arrays, which cannot be cast to Object[].
            return java.lang.reflect.Array.get(result, outputIndex);
        }

        return result;
    }

    /**
     * Reads a property from an object (the obj.xxx form).
     * <p>A field with the given name is preferred; when there is none, a method with that
     * name is invoked instead.
     * @param obj the object to read from
     * @param fieldName name of the field or method
     * @return the property value; null when {@code obj} is null or no such field or method exists
     */
    private static Object getPropertyValue(Object obj, @NotNull String fieldName) {
        if (obj == null) {
            return null;
        }

        Class<?> type = obj.getClass();
        Field field = ReflectUtil.findField(type, fieldName);
        if (field != null) {
            field.setAccessible(true);
            return ReflectUtil.getField(field, obj);
        }

        Method accessor = ReflectUtil.findMethod(type, fieldName);
        if (accessor == null) {
            return null;
        }
        return ReflectUtil.invokeMethod(accessor, obj);
    }

    /**
     * Returns the result of a preceding action.
     * <p>If that action has not produced a result yet it is executed right away; resolving
     * its own parameters may in turn execute actions further upstream, and so on.
     * @param action the current action
     * @param param the current action's parameter; after removing {@code Constants.OUTPUT}
     *              it must be the 1-based position of the preceding action
     * @return the result of the preceding action
     */
    private static Object getPreActionResult(Action action, String param) {
        String position = param.replace(Constants.OUTPUT, "");
        Action predecessor = action.getPreActions().get(Integer.parseInt(position) - 1);

        Object cached = predecessor.getResult();
        return cached != null ? cached : execAction(predecessor);
    }

    /**
     * Executes a single action on demand and reports its state.
     * <p>A "Started" line is written before the call and an "Ended" line afterwards, both to
     * the logger and to standard output. The "Ended" line carries ":1" on normal completion
     * and ":0" when an exception was raised, in which case the stack frames are written out
     * as well.
     *
     * @param action the action to execute
     * @return the value returned by the invocation; null when an exception was raised
     */
    private static Object execAction(Action action) {
        String started = "===datastax-Insight Component Started==="+
                action.getId()+":"+action.getType()+"===";
        logger.info(started);
        System.out.println(started);

        Object outcome = null;
        try {
            outcome = invoke(action, prepareParams(action, null));

            String succeeded = "===datastax-Insight Component Ended==="+
                    action.getId()+":1"+":"+action.getType()+"===";
            logger.info(succeeded);
            System.out.println(succeeded);
        } catch (Exception ex) {
            for (StackTraceElement frame : ex.getStackTrace()) {
                String line = frame.toString();
                logger.info(line);
                System.out.println(line);
            }

            String failed = "===datastax-Insight Component Ended==="+
                    action.getId()+":0"+":"+action.getType()+"===";
            logger.info(failed);
            System.out.println(failed);
        }

        return outcome;
    }

    /**
     * Finds the container action that the given action belongs to.
     * <p>A container is an action of the requested type whose member ids include the id of
     * the given action. The supported container types are pipeline and transformer.
     * @param action the current action (a member of the container)
     * @param actionType type of the container to look for: pipeline or transformer
     * @return the first matching container in the registry, or null when there is none
     */
    private static Action getAction(Action action, String actionType) {
        for (Action candidate : allActions) {
            String type = candidate.getType();
            if (Strings.isNullOrEmpty(type) || !type.equals(actionType)) {
                continue;
            }

            long[] memberIds = candidate.getActions();
            if (memberIds == null) {
                continue;
            }

            for (long memberId : memberIds) {
                if (memberId == action.getId()) {
                    return candidate;
                }
            }
        }
        return null;
    }

    /**
     * Parses the DAG into actions, beginning at the start vertex.
     * <p>The action registry is re-created, a root action named after the start vertex is
     * registered, and the graph is walked from there.
     *
     * @param dag the DAG to parse
     * @return the root action, with the DAG attached to it
     */
    public static Action parseDAG(DAG dag){
        allActions = new ArrayList<>();

        Vertex startVertex = dag.getVByName(DAG.START_VERTEX);

        Action root = new Action();
        root.setName(DAG.START_VERTEX);
        root.setId(startVertex.getId());
        allActions.add(root);

        parseVertex(dag, startVertex, root);
        root.setDag(dag);
        return root;
    }

    /**
     * Recursively converts the successors of a vertex into actions and links them to the
     * action of that vertex.
     * @param dag the DAG being parsed
     * @param vertex the vertex whose successors are processed
     * @param action the action that corresponds to {@code vertex}
     */
    private static void parseVertex(DAG dag,Vertex vertex,Action action){
        List nextList=dag.next(vertex.getId());
        List sortedList=new ArrayList<>();
        // presumably fills sortedList with the successors in execution order - verify against sortVList
        sortVList(nextList,sortedList,dag);

        for(Vertex v : sortedList) {

            //Look the action up in the registry first, so that a node reachable over several paths is not executed repeatedly
            Action nextAction=getActionFromList(v.getId());
            if(nextAction==null) {
                //When adding an action, the flag tells whether it is a preceding or a following action
                nextAction = vertex2Action(dag, v, action, false);
                //Only a newly created action joins the execution queue, to avoid repeated execution
                action.getNextActions().add(nextAction);
            }

            //2017-12-09 bug fix: List.contains would require equals to be implemented, so a dedicated check is used
            if(!nextAction.containsPreAction(action)) {
                addPreAction(nextAction, action, dag);
            }

            if(nextAction.getActions() != null && nextAction.getActions().length > 0) {
                //The first member is not guaranteed to be the first node of the flow inside the pipeline, so every member is visited
                for (long pId : nextAction.getActions()) {
                    Vertex pipelineVertex = dag.getVById(pId);
                    if (pipelineVertex != null) {
                        Action pipelineAction = getActionFromList(pipelineVertex.getId());
                        if (pipelineAction == null) {
                            //When adding an action, the flag tells whether it is a preceding or a following action
                            pipelineAction = vertex2Action(dag, pipelineVertex, nextAction, false);
                            parseVertex(dag, pipelineVertex, pipelineAction);
                        }
                    }
                }
            }

            //Handle preceding nodes (nodes with in-degree 0)
            parsePreVertex(v, nextAction, dag);

            parseVertex(dag,v,nextAction);
        }

        action.sortPreActions();
    }

    /**
     * Links predecessors of {@code vertex} that are not yet recorded as pre-actions of
     * {@code action}, then repeats the process for each such predecessor.
     *
     * <p>According to the original note, this is how vertices with in-degree 0 are added to
     * the execution sequence.
     *
     * @param vertex the current vertex
     * @param action the action that corresponds to {@code vertex}
     * @param dag    the graph being parsed
     */
    private static void parsePreVertex(Vertex vertex, Action action, DAG dag) {
        List<Vertex> predecessors = dag.prev(vertex.getId());
        // Snapshot of the ids linked before this call; links added below do not affect it.
        List<Long> linkedIds = action.getPreActions().stream()
                .map(Component::getId)
                .collect(Collectors.toList());

        for (Vertex predecessor : predecessors) {
            boolean notLinkedYet = linkedIds.stream().noneMatch(id -> id == predecessor.getId());
            if (!notLinkedYet) {
                continue;
            }

            Action upstream = getActionFromList(predecessor.getId());
            if (upstream == null) {
                // Created as a predecessor of the current action (previous = true).
                upstream = vertex2Action(dag, predecessor, action, true);
            }

            // Membership goes through the Action helpers rather than List#contains
            // (which would need an equals implementation; fix of 2017-12-09).
            if (!upstream.containsNextAction(action)) {
                upstream.getNextActions().add(action);
            }
            if (!action.containsPreAction(upstream)) {
                addPreAction(action, upstream, dag);
            }

            parsePreVertex(predecessor, upstream, dag);
        }
    }

    /**
     * Registers {@code preAction} as a pre-action of {@code action}.
     *
     * <p>If the graph has an edge from {@code preAction} to {@code action} whose parameter list
     * is non-empty, the {@code input} value of the first edge parameter is passed to
     * {@code action.setPreActionOrder} together with the pre-action's id before the
     * pre-action itself is added.
     *
     * @param action    action receiving the pre-action
     * @param preAction action to register
     * @param dag       graph in which the connecting edge is looked up
     */
    private static void addPreAction(Action action, Action preAction, DAG dag) {
        dag.findEdge(preAction.getId(), action.getId())
                .map(Edge::getParameters)
                .filter(edgeParams -> edgeParams.size() > 0)
                .ifPresent(edgeParams -> {
                    int inputSlot = edgeParams.get(0).getInput();
                    action.setPreActionOrder(inputSlot, preAction.getId());
                });

        action.addPreAction(preAction);
    }

    /**
     * Creates the {@link Action} for {@code vertex}, registers it in {@code allActions} and
     * runs the component-to-action mapping on it.
     *
     * <p>Id, name, type, member action ids, parameters and parameter orders are copied from
     * the vertex. The level is derived from {@code parent}: one below it when the vertex is
     * added as a predecessor, one above it when added as a successor.
     *
     * <p>For the parameter-grid builder only the first two parameters are kept. When fewer
     * than two parameters are present there is nothing to discard and the parameters are left
     * as they are (previously this case failed with a {@code NullPointerException} or
     * {@code ArrayIndexOutOfBoundsException}).
     *
     * @param dag      graph the vertex belongs to
     * @param vertex   vertex to convert
     * @param parent   action relative to which the level is computed
     * @param previous {@code true} if the vertex is a predecessor of {@code parent},
     *                 {@code false} if it is a successor
     * @return the newly created and registered action
     * @throws NumberFormatException if an entry of the vertex's action ids or parameter
     *                               orders is not a valid {@code long}
     */
    private static Action vertex2Action(DAG dag, Vertex vertex, Action parent, boolean previous) {
        Action action = new Action();
        action.setId(vertex.getId());
        action.setName(vertex.getName());
        action.setType(vertex.getType());
        action.setDag(dag);

        // Predecessors sit one level below the parent, successors one level above.
        action.setLevel(previous ? parent.getLevel() - 1 : parent.getLevel() + 1);

        if (!Strings.isNullOrEmpty(vertex.getActions())) {
            long[] memberIds = Arrays.stream(vertex.getActions().split(";"))
                    .mapToLong(Long::parseLong)
                    .toArray();
            action.setActions(memberIds);
        }

        if (vertex.getParameters() != null) {
            action.setParameters(vertex.getParameters().toArray(new Parameter[0]));
        }

        String paramOrders = vertex.getParamOrders();
        if (paramOrders != null && paramOrders.length() > 0) {
            long[] orders = Arrays.stream(paramOrders.split(Consts.DELIMITER))
                    .mapToLong(Long::parseLong)
                    .toArray();
            action.setParamOrders(orders);
        }

        allActions.add(action);

        // component -> action mapping
        mapAction(action);

        // The parameter grid carries surplus parameters that must be dropped: keep the
        // first two. Guarded so that a grid with fewer than two parameters does not crash.
        if (ParamGridBuilderWrapper.equals(action.getClassName())) {
            Parameter[] params = action.getParameters();
            if (params != null && params.length >= 2) {
                action.setParameters(new Parameter[] {params[0], params[1]});
            }
        }

        return action;
    }

    /**
     * Looks up a registered action by id.
     *
     * @param actionId id to search for
     * @return the first action in {@code allActions} whose id equals {@code actionId},
     *         or {@code null} if there is none
     */
    private static Action getActionFromList(long actionId) {
        return allActions.stream()
                .filter(candidate -> candidate.getId() == actionId)
                .findFirst()
                .orElse(null);
    }

    /**
     * Topologically orders the vertices of {@code vList} by repeatedly removing the vertices
     * that have no predecessor left in {@code vList}.
     *
     * <p>In each pass every vertex of {@code vList} for which {@code getPrev} finds no
     * predecessor inside {@code vList} is appended to {@code sortedList}; afterwards all
     * vertices contained in {@code sortedList} are removed from {@code vList}. The passes
     * repeat until {@code vList} is empty, so {@code vList} is consumed by this method.
     *
     * <p>The earlier implementation combined a {@code while} loop with self-recursion and
     * never terminated (ending in a {@code StackOverflowError}) when no remaining vertex was
     * free of predecessors. That situation is now reported explicitly.
     *
     * @param vList      vertices to order; emptied as a side effect
     * @param sortedList receives the vertices in topological order
     * @param dag        graph used to resolve predecessors
     * @throws IllegalStateException if a pass removes nothing, i.e. every remaining vertex
     *                               still has a predecessor among the remaining vertices
     */
    private static void sortVList(List<Vertex> vList, List<Vertex> sortedList, DAG dag) {
        while (!vList.isEmpty()) {
            int remainingBefore = vList.size();

            for (Vertex candidate : vList) {
                if (getPrev(candidate, vList, dag).isEmpty()) {
                    sortedList.add(candidate);
                }
            }
            for (Vertex placed : sortedList) {
                vList.remove(placed);
            }

            // No progress means the remaining vertices depend on each other cyclically.
            if (vList.size() == remainingBefore) {
                throw new IllegalStateException(
                        "Cannot sort vertices topologically: " + remainingBefore
                                + " remaining vertices all have a predecessor among themselves");
            }
        }
    }

    /**
     * Returns the predecessors of {@code current} that also occur in {@code vList}.
     *
     * <p>Vertices are matched by id; {@code null} entries on either side are ignored. A
     * predecessor is added once for every entry of {@code vList} that carries its id.
     *
     * <p>The parameter and return types are declared as {@code List<Vertex>} instead of raw
     * {@code List}, so the typed iteration below is well-formed; callers using raw lists
     * remain source-compatible.
     *
     * @param current vertex whose predecessors are inspected
     * @param vList   vertices to match the predecessors against
     * @param dag     graph used to resolve predecessors
     * @return the matching predecessors, in the order returned by {@code dag.prev}
     */
    private static List<Vertex> getPrev(Vertex current, List<Vertex> vList, DAG dag) {
        List<Vertex> predecessors = dag.prev(current.getId());
        List<Vertex> matched = new ArrayList<>();
        for (Vertex predecessor : predecessors) {
            for (Vertex candidate : vList) {
                // NOTE(review): '==' assumes getId() yields a primitive; confirm it is not a boxed Long.
                if (predecessor != null && candidate != null
                        && predecessor.getId() == candidate.getId()) {
                    matched.add(predecessor);
                }
            }
        }
        return matched;
    }

    /**
     * Runs the component-to-action mapping for {@code action} by delegating to
     * {@code action.fromComponent(action)}.
     *
     * @param action action to map; passed to itself as the component argument
     */
    private static void mapAction(Action action) {
        action.fromComponent(action);
    }

    /**
     * Down-samples the result of {@code action} when the run mode is data exploration or
     * small-scale sample run.
     *
     * <p>Nothing happens unless the mode is {@code RUN_EXPLORATION} or {@code RUN_SAMPLE} and
     * the action's type equals the {@code DatumLoader} constant. Otherwise, if the result
     * holds more than {@code hits} rows, it is replaced by a sample of {@code hits} elements
     * taken without replacement.
     *
     * @param action action whose result may be replaced
     */
    private static void sample(Action action) {
        boolean samplingMode =
                RunMode.RUN_EXPLORATION.equals(mode) || RunMode.RUN_SAMPLE.equals(mode);

        // NOTE(review): DatumLoader is a class name, while vertex2Action compares such constants
        // against getClassName(); confirm that getType() is the intended accessor here.
        if (samplingMode && DatumLoader.equals(action.getType())) {
            Dataset<?> loaded = (Dataset<?>) action.getResult();
            if (loaded.count() > hits) {
                // NOTE(review): takeSample yields a java.util.List, so the stored result is no
                // longer a Dataset afterwards; verify that consumers of the result expect this.
                action.setResult(loaded.javaRDD().takeSample(false, hits));
            }
        }
    }

    /**
     * Saves the result of {@code action} to the configured exploration file.
     *
     * <p>Reflectively invokes {@code save} on the class named by the {@code DataSetWriter}
     * constant with the arguments {@code (result, "csv", "overwrite", explorationFile)}; this
     * stands in for a direct {@code DataSetWriter.save(...)} call.
     *
     * @param action action whose result is written; the result is cast to {@code Dataset}
     * @throws Exception whatever the reflective invocation raises
     */
    private static void writeExploration(Action action) throws Exception {
        Dataset<?> explored = (Dataset<?>) action.getResult();
        ReflectUtil.invokeMethod(DataSetWriter, "save", explored, "csv", "overwrite", explorationFile);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy