All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.airavata.gfac.gsissh.util.GFACGSISSHUtils Maven / Gradle / Ivy

The newest version!
/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
*/
package org.apache.airavata.gfac.gsissh.util;

import org.airavata.appcatalog.cpi.AppCatalog;
import org.apache.airavata.common.exception.ApplicationSettingsException;
import org.apache.airavata.common.utils.ServerSettings;
import org.apache.airavata.credential.store.credential.impl.certificate.CertificateCredential;
import org.apache.airavata.credential.store.store.CredentialReader;
import org.apache.airavata.gfac.GFacException;
import org.apache.airavata.gfac.RequestData;
import org.apache.airavata.gfac.core.context.JobExecutionContext;
import org.apache.airavata.gfac.core.context.MessageContext;
import org.apache.airavata.gfac.core.utils.GFacUtils;
import org.apache.airavata.gfac.gsissh.security.GSISecurityContext;
import org.apache.airavata.gfac.gsissh.security.TokenizedMyProxyAuthInfo;
import org.apache.airavata.gsi.ssh.api.Cluster;
import org.apache.airavata.gsi.ssh.api.ServerInfo;
import org.apache.airavata.gsi.ssh.api.job.JobDescriptor;
import org.apache.airavata.gsi.ssh.api.job.JobManagerConfiguration;
import org.apache.airavata.gsi.ssh.impl.GSISSHAbstractCluster;
import org.apache.airavata.gsi.ssh.impl.PBSCluster;
import org.apache.airavata.gsi.ssh.util.CommonUtils;
import org.apache.airavata.model.appcatalog.appdeployment.ApplicationDeploymentDescription;
import org.apache.airavata.model.appcatalog.appdeployment.ApplicationParallelismType;
import org.apache.airavata.model.appcatalog.appinterface.DataType;
import org.apache.airavata.model.appcatalog.appinterface.InputDataObjectType;
import org.apache.airavata.model.appcatalog.appinterface.OutputDataObjectType;
import org.apache.airavata.model.appcatalog.computeresource.*;
import org.apache.airavata.model.appcatalog.gatewayprofile.ComputeResourcePreference;
import org.apache.airavata.model.workspace.experiment.ComputationalResourceScheduling;
import org.apache.airavata.model.workspace.experiment.TaskDetails;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.*;


public class GFACGSISSHUtils {
    private final static Logger logger = LoggerFactory.getLogger(GFACGSISSHUtils.class);

    public static final String PBS_JOB_MANAGER = "pbs";
    public static final String SLURM_JOB_MANAGER = "slurm";
    public static final String SUN_GRID_ENGINE_JOB_MANAGER = "UGE";
    public static final String LSF_JOB_MANAGER = "lsf";

    public static int maxClusterCount = 5;
    public static Map> clusters = new HashMap>();

    public static void addSecurityContext(JobExecutionContext jobExecutionContext) throws GFacException, ApplicationSettingsException {
        JobSubmissionInterface jobSubmissionInterface = jobExecutionContext.getPreferredJobSubmissionInterface();
        JobSubmissionProtocol jobProtocol = jobSubmissionInterface.getJobSubmissionProtocol();
        try {
            AppCatalog appCatalog = jobExecutionContext.getAppCatalog();
            SSHJobSubmission sshJobSubmission = appCatalog.getComputeResource().getSSHJobSubmission(jobSubmissionInterface.getJobSubmissionInterfaceId());
            if (jobProtocol == JobSubmissionProtocol.GLOBUS || jobProtocol == JobSubmissionProtocol.UNICORE
                    || jobProtocol == JobSubmissionProtocol.CLOUD || jobProtocol == JobSubmissionProtocol.LOCAL) {
                logger.error("This is a wrong method to invoke to non ssh host types,please check your gfac-config.xml");
            } else if (jobProtocol == JobSubmissionProtocol.SSH && sshJobSubmission.getSecurityProtocol() == SecurityProtocol.GSI) {
                String credentialStoreToken = jobExecutionContext.getCredentialStoreToken(); // this is set by the framework
                RequestData requestData = new RequestData(jobExecutionContext.getGatewayID());
                requestData.setTokenId(credentialStoreToken);
                PBSCluster pbsCluster = null;
                GSISecurityContext context = null;

                TokenizedMyProxyAuthInfo tokenizedMyProxyAuthInfo = new TokenizedMyProxyAuthInfo(requestData);
                CredentialReader credentialReader = GFacUtils.getCredentialReader();
                if (credentialReader != null) {
                    CertificateCredential credential = null;
                    try {
                        credential = (CertificateCredential) credentialReader.getCredential(jobExecutionContext.getGatewayID(), credentialStoreToken);
                        requestData.setMyProxyUserName(credential.getCommunityUser().getUserName());
                    } catch (Exception e) {
                        logger.error(e.getLocalizedMessage());
                    }
                }

                String key = requestData.getMyProxyUserName() + jobExecutionContext.getHostName()+
                        sshJobSubmission.getSshPort();
                boolean recreate = false;
                synchronized (clusters) {
                    if (clusters.containsKey(key) && clusters.get(key).size() < maxClusterCount) {
                        recreate = true;
                    } else if (clusters.containsKey(key)) {
                        int i = new Random().nextInt(Integer.MAX_VALUE) % maxClusterCount;
                        if (clusters.get(key).get(i).getSession().isConnected()) {
                            pbsCluster = (PBSCluster) clusters.get(key).get(i);
                        } else {
                            clusters.get(key).remove(i);
                            recreate = true;
                        }
                        if (!recreate) {
                            try {
                                pbsCluster.listDirectory("~/"); // its hard to trust isConnected method, so we try to connect if it works we are good,else we recreate
                            } catch (Exception e) {
                                clusters.get(key).remove(i);
                                logger.info("Connection found the connection map is expired, so we create from the scratch");
                                maxClusterCount++;
                                recreate = true; // we make the pbsCluster to create again if there is any exception druing connection
                            }
                            logger.info("Re-using the same connection used with the connection string:" + key);
                            context = new GSISecurityContext(tokenizedMyProxyAuthInfo.getCredentialReader(), requestData, pbsCluster);
                        }
                    } else {
                        recreate = true;
                    }

                    if (recreate) {
                        ServerInfo serverInfo = new ServerInfo(requestData.getMyProxyUserName(), jobExecutionContext.getHostName(),
                                sshJobSubmission.getSshPort());

                        JobManagerConfiguration jConfig = null;
                        String installedParentPath = sshJobSubmission.getResourceJobManager().getJobManagerBinPath();
                        String jobManager = sshJobSubmission.getResourceJobManager().getResourceJobManagerType().toString();
                        if (jobManager == null) {
                            logger.error("No Job Manager is configured, so we are picking pbs as the default job manager");
                            jConfig = CommonUtils.getPBSJobManager(installedParentPath);
                        } else {
                            if (PBS_JOB_MANAGER.equalsIgnoreCase(jobManager)) {
                                jConfig = CommonUtils.getPBSJobManager(installedParentPath);
                            } else if (SLURM_JOB_MANAGER.equalsIgnoreCase(jobManager)) {
                                jConfig = CommonUtils.getSLURMJobManager(installedParentPath);
                            } else if (SUN_GRID_ENGINE_JOB_MANAGER.equalsIgnoreCase(jobManager)) {
                                jConfig = CommonUtils.getUGEJobManager(installedParentPath);
                            }else if(LSF_JOB_MANAGER.equalsIgnoreCase(jobManager)) {
                                jConfig = CommonUtils.getLSFJobManager(installedParentPath);
                            }
                        }
                        pbsCluster = new PBSCluster(serverInfo, tokenizedMyProxyAuthInfo, jConfig);
                        context = new GSISecurityContext(tokenizedMyProxyAuthInfo.getCredentialReader(), requestData, pbsCluster);
                        List pbsClusters = null;
                        if (!(clusters.containsKey(key))) {
                            pbsClusters = new ArrayList();
                        } else {
                            pbsClusters = clusters.get(key);
                        }
                        pbsClusters.add(pbsCluster);
                        clusters.put(key, pbsClusters);
                    }
                }

                jobExecutionContext.addSecurityContext(jobExecutionContext.getHostName(), context);
            }
        } catch (Exception e) {
            throw new GFacException("An error occurred while creating GSI security context", e);
        }
    }

    public static JobDescriptor createJobDescriptor(JobExecutionContext jobExecutionContext, Cluster cluster) {
        JobDescriptor jobDescriptor = new JobDescriptor();
        TaskDetails taskData = jobExecutionContext.getTaskData();
        ResourceJobManager resourceJobManager = jobExecutionContext.getResourceJobManager();
        try {
			if(ServerSettings.getSetting(ServerSettings.JOB_NOTIFICATION_ENABLE).equalsIgnoreCase("true")){
				jobDescriptor.setMailOptions(ServerSettings.getSetting(ServerSettings.JOB_NOTIFICATION_FLAGS));
				String emailids = ServerSettings.getSetting(ServerSettings.JOB_NOTIFICATION_EMAILIDS);

				if(jobExecutionContext.getTaskData().isSetEmailAddresses()){
					List emailList = jobExecutionContext.getTaskData().getEmailAddresses();
					String elist = GFacUtils.listToCsv(emailList, ',');
					if(emailids != null && !emailids.isEmpty()){
						emailids = emailids +"," + elist;
					}else{
						emailids = elist;
					}
				}
				if(emailids != null && !emailids.isEmpty()){
					logger.info("Email list: "+ emailids);
					jobDescriptor.setMailAddress(emailids);
				}
			}
		} catch (ApplicationSettingsException e) {
			 logger.error("ApplicationSettingsException : " +e.getLocalizedMessage());
		}
        // this is common for any application descriptor
        jobDescriptor.setCallBackIp(ServerSettings.getIp());
        jobDescriptor.setCallBackPort(ServerSettings.getSetting(org.apache.airavata.common.utils.Constants.GFAC_SERVER_PORT, "8950"));
        jobDescriptor.setInputDirectory(jobExecutionContext.getInputDir());
        jobDescriptor.setOutputDirectory(jobExecutionContext.getOutputDir());
        jobDescriptor.setExecutablePath(jobExecutionContext.getExecutablePath());
        jobDescriptor.setStandardOutFile(jobExecutionContext.getStandardOutput());
        jobDescriptor.setStandardErrorFile(jobExecutionContext.getStandardError());
        String computationalProjectAccount = taskData.getTaskScheduling().getComputationalProjectAccount();
        taskData.getEmailAddresses();
        if (computationalProjectAccount == null){
            ComputeResourcePreference computeResourcePreference = jobExecutionContext.getApplicationContext().getComputeResourcePreference();
            if (computeResourcePreference != null) {
                computationalProjectAccount = computeResourcePreference.getAllocationProjectNumber();
            }
        }
        if (computationalProjectAccount != null) {
            jobDescriptor.setAcountString(computationalProjectAccount);
        }

        Random random = new Random();
        int i = random.nextInt(Integer.MAX_VALUE); // We always set the job name
        jobDescriptor.setJobName("A" + String.valueOf(i+99999999));
        jobDescriptor.setWorkingDirectory(jobExecutionContext.getWorkingDir());

        List inputValues = new ArrayList();
        MessageContext input = jobExecutionContext.getInMessageContext();
        // sort the inputs first and then build the command List
        Comparator inputOrderComparator = new Comparator() {
            @Override
            public int compare(InputDataObjectType inputDataObjectType, InputDataObjectType t1) {
                return inputDataObjectType.getInputOrder() - t1.getInputOrder();
            }
        };
        Set sortedInputSet = new TreeSet(inputOrderComparator);
        for (Object object : input.getParameters().values()) {
            if (object instanceof InputDataObjectType) {
                InputDataObjectType inputDOT = (InputDataObjectType) object;
                sortedInputSet.add(inputDOT);
            }
        }
        for (InputDataObjectType inputDataObjectType : sortedInputSet) {
            if (!inputDataObjectType.isRequiredToAddedToCommandLine()) {
                continue;
            }
            if (inputDataObjectType.getApplicationArgument() != null
                    && !inputDataObjectType.getApplicationArgument().equals("")) {
                inputValues.add(inputDataObjectType.getApplicationArgument());
            }

            if (inputDataObjectType.getValue() != null
                    && !inputDataObjectType.getValue().equals("")) {
                if (inputDataObjectType.getType() == DataType.URI) {
                    // set only the relative path
                    String filePath = inputDataObjectType.getValue();
                    filePath = filePath.substring(filePath.lastIndexOf(File.separatorChar) + 1, filePath.length());
                    inputValues.add(filePath);
                }else {
                    inputValues.add(inputDataObjectType.getValue());
                }

            }
        }

        Map outputParams = jobExecutionContext.getOutMessageContext().getParameters();
        for (Object outputParam : outputParams.values()) {
            if (outputParam instanceof OutputDataObjectType) {
                OutputDataObjectType output = (OutputDataObjectType) outputParam;
                if (output.getApplicationArgument() != null
                        && !output.getApplicationArgument().equals("")) {
                    inputValues.add(output.getApplicationArgument());
                }
                if (output.getValue() != null && !output.getValue().equals("") && output.isRequiredToAddedToCommandLine()) {
                    if (output.getType() == DataType.URI){
                        String filePath = output.getValue();
                        filePath = filePath.substring(filePath.lastIndexOf(File.separatorChar) + 1, filePath.length());
                        inputValues.add(filePath);
                    }
                }
            }
        }
        jobDescriptor.setInputValues(inputValues);

        jobDescriptor.setUserName(((GSISSHAbstractCluster) cluster).getServerInfo().getUserName());
        jobDescriptor.setShellName("/bin/bash");
        jobDescriptor.setAllEnvExport(true);
        jobDescriptor.setOwner(((PBSCluster) cluster).getServerInfo().getUserName());

        ComputationalResourceScheduling taskScheduling = taskData.getTaskScheduling();
        if (taskScheduling != null) {
            int totalNodeCount = taskScheduling.getNodeCount();
            int totalCPUCount = taskScheduling.getTotalCPUCount();

//        jobDescriptor.setJobSubmitter(applicationDeploymentType.getJobSubmitterCommand());
            if (taskScheduling.getComputationalProjectAccount() != null) {
                jobDescriptor.setAcountString(taskScheduling.getComputationalProjectAccount());
            }
            if (taskScheduling.getQueueName() != null) {
                jobDescriptor.setQueueName(taskScheduling.getQueueName());
            }

            if (totalNodeCount > 0) {
                jobDescriptor.setNodes(totalNodeCount);
            }
            if (taskScheduling.getComputationalProjectAccount() != null) {
                jobDescriptor.setAcountString(taskScheduling.getComputationalProjectAccount());
            }
            if (taskScheduling.getQueueName() != null) {
                jobDescriptor.setQueueName(taskScheduling.getQueueName());
            }
            if (totalCPUCount > 0) {
                int ppn = totalCPUCount / totalNodeCount;
                jobDescriptor.setProcessesPerNode(ppn);
                jobDescriptor.setCPUCount(totalCPUCount);
            }
            if (taskScheduling.getWallTimeLimit() > 0) {
                jobDescriptor.setMaxWallTime(String.valueOf(taskScheduling.getWallTimeLimit()));
                if(resourceJobManager.getResourceJobManagerType().equals(ResourceJobManagerType.LSF)){
                    jobDescriptor.setMaxWallTimeForLSF(String.valueOf(taskScheduling.getWallTimeLimit()));
                }
            }

            if (taskScheduling.getTotalPhysicalMemory() > 0) {
                jobDescriptor.setUsedMemory(taskScheduling.getTotalPhysicalMemory() + "");
            }
        } else {
            logger.error("Task scheduling cannot be null at this point..");
        }

        ApplicationDeploymentDescription appDepDescription = jobExecutionContext.getApplicationContext().getApplicationDeploymentDescription();
        List moduleCmds = appDepDescription.getModuleLoadCmds();
        if (moduleCmds != null) {
            for (String moduleCmd : moduleCmds) {
                jobDescriptor.addModuleLoadCommands(moduleCmd);
            }
        }
        List preJobCommands = appDepDescription.getPreJobCommands();
        if (preJobCommands != null) {
            for (String preJobCommand : preJobCommands) {
                jobDescriptor.addPreJobCommand(parseCommand(preJobCommand, jobExecutionContext));
            }
        }

        List postJobCommands = appDepDescription.getPostJobCommands();
        if (postJobCommands != null) {
            for (String postJobCommand : postJobCommands) {
                jobDescriptor.addPostJobCommand(parseCommand(postJobCommand, jobExecutionContext));
            }
        }

        ApplicationParallelismType parallelism = appDepDescription.getParallelism();
        if (parallelism != null){
            if (parallelism == ApplicationParallelismType.MPI || parallelism == ApplicationParallelismType.OPENMP || parallelism == ApplicationParallelismType.OPENMP_MPI){
                Map jobManagerCommands = resourceJobManager.getJobManagerCommands();
                if (jobManagerCommands != null && !jobManagerCommands.isEmpty()) {
                    for (JobManagerCommand command : jobManagerCommands.keySet()) {
                        if (command == JobManagerCommand.SUBMISSION) {
                            String commandVal = jobManagerCommands.get(command);
                            jobDescriptor.setJobSubmitter(commandVal);
                        }
                    }
                }
            }
        }
        return jobDescriptor;
    }

    private static String parseCommand(String value, JobExecutionContext jobExecutionContext) {
        String parsedValue = value.replaceAll("\\$workingDir", jobExecutionContext.getWorkingDir());
        parsedValue = parsedValue.replaceAll("\\$inputDir", jobExecutionContext.getInputDir());
        parsedValue = parsedValue.replaceAll("\\$outputDir", jobExecutionContext.getOutputDir());
        return parsedValue;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy