All Downloads are FREE. Search and download functionalities are using the official Maven repository.

prerna.engine.impl.model.TextGenerationProcessInference Maven / Gradle / Ivy

The newest version!
package prerna.engine.impl.model;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.io.FilenameUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.google.common.collect.ImmutableMap;

import prerna.sablecc2.om.execptions.SemossPixelException;
import prerna.util.Constants;
import prerna.util.PortAllocator;
import prerna.util.Utility;

public class TextGenerationProcessInference extends TextGenerationEngine {
	
	private static Logger classLogger = LogManager.getLogger(TextGenerationProcessInference.class);
	
	private HashMap launchArguments = new HashMap();
	private Process process;
    private String workerAddress;
    private String controllerAddress;
    private String inferencePort;
    private ImmutableMap possibleInputs = getLauncherArgs();
	
	@Override
	public void open(String smssFilePath) {
		try {
			if (smssFilePath != null) {
				classLogger.info("Loading Model - " + Utility.cleanLogString(FilenameUtils.getName(smssFilePath)));
				setSmssFilePath(smssFilePath);
				setSmssProp(Utility.loadProperties(smssFilePath));
			}
			for (String launcherArg : possibleInputs.keySet()) {
				String propArg = (String) smssProp.get(launcherArg);
				if(propArg != null && !propArg.isEmpty()){
					this.launchArguments.put(launcherArg,propArg);
				}
			}
	        if (!this.launchArguments.containsKey(Constants.MODEL)) {
	        	throw new IllegalArgumentException("Model name is a required argument.");
	        } 
	        
	        if (!this.launchArguments.containsKey("PORT")) {
	        	this.inferencePort=PortAllocator.getInstance().getNextAvailablePort()+"";
	        	launchArguments.put("PORT", this.inferencePort);
	        } else {
	        	this.inferencePort=this.launchArguments.get("PORT");
	        }
			
			
			this.workerAddress = Utility.getDIHelperProperty(Constants.WORKER_ADDRESS);
			if (this.workerAddress == null || this.workerAddress.trim().isEmpty()) {
				this.workerAddress = System.getenv(Constants.WORKER_ADDRESS);
			}
			if (this.controllerAddress ==null || this.controllerAddress.trim().isEmpty()) {
				this.controllerAddress = this.workerAddress + ":" + this.inferencePort;
			}
			
			if (!this.smssProp.containsKey("ENDPOINT")) {
				this.smssProp.put("ENDPOINT", this.controllerAddress);
			} 
			
			// create a generic folder
			this.workingDirectory = "EM_MODEL_" + Utility.getRandomString(6);
			this.workingDirectoryBasePath = Utility.getInsightCacheDir() + "/" + workingDirectory;
			this.cacheFolder = new File(workingDirectoryBasePath);
			
			// make the folder if one does not exist
			if(!cacheFolder.exists())
				cacheFolder.mkdir();
			
			// vars for string substitution
			for (Object smssKey : this.smssProp.keySet()) {
				String key = smssKey.toString();
				this.vars.put(key, this.smssProp.getProperty(key));
			}
		} catch(Exception e) {
			classLogger.error(Constants.STACKTRACE, e);
			throw new SemossPixelException("Unable to load model details from the SMSS file");
		}
	}
	

	@Override
	public void startServer(int port) {
		List command = new ArrayList<>();
		if (this.launchArguments.containsKey(Constants.GPU_ID)) {
    		command.add(possibleInputs.get(Constants.GPU_ID) + "=" + this.launchArguments.get(Constants.GPU_ID));
        	this.launchArguments.remove(Constants.GPU_ID);
		}
		command.add("text-generation-launcher");
		for (String arg : this.launchArguments.keySet()) {
    		command.add(possibleInputs.get(arg));
        	command.add(this.launchArguments.get(arg));
		}
		
		System.out.println("Executing command: " + String.join(" ", command));
		
		ProcessBuilder processBuilder = new ProcessBuilder(command);
		// dont inherit IO so that we can catch 
        // processBuilder.inheritIO(); // Redirect the subprocess's standard error and output to the current process
       
        try {
			this.process = processBuilder.start();
		} catch (IOException e) {
			// TODO Auto-generated catch block
			classLogger.error(Constants.STACKTRACE, e);
		}
        
        // Wait for the process to finish loading
        // Read the output of the process
        String logPattern = "\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{6}Z  INFO.*Connected";
        Pattern pattern = Pattern.compile(logPattern);
        // Read the output log file in real-time
        try (InputStream inputStream = process.getInputStream();
        		BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Check if the line matches the log pattern
            	classLogger.info(line);
                Matcher matcher = pattern.matcher(line);
                if (matcher.find() || line.endsWith("Connected")) {
                    // Process the matching log line, if needed
                	classLogger.info("Process has finished loading");
                    break;
                }
            }
        } catch (IOException e) {
        	classLogger.error(Constants.STACKTRACE, e);
		}
        
        super.startServer(-1);
	}
	
	@Override
	public void close() {
        if (process != null) {
            // Attempt to gracefully shut down the Python process first
            process.destroy();

            try {
                // Create a separate thread to wait for the process to exit with a timeout of 5 seconds
                Thread waitThread = new Thread(() -> {
                    try {
                        process.waitFor();
                    } catch (InterruptedException e) {
                        classLogger.error(Constants.STACKTRACE, e);
                    }
                });

                // Start the thread to wait for the process
                waitThread.start();

                // Wait for the thread to complete with a timeout of 5 seconds
                waitThread.join(5000);

                // Optionally, you can try to forcibly terminate the process if it's still running after graceful shutdown
                if (process.isAlive()) {
                    process.destroyForcibly();
                    System.out.println("Process forcefully terminated.");
                }

            } catch (InterruptedException e) {
                classLogger.error(Constants.STACKTRACE, e);
            }
        }
        try {
			super.close();
		} catch (IOException e) {
			classLogger.error(Constants.STACKTRACE, e);
		}
	}
	



    private static ImmutableMap getLauncherArgs() {
    	return new ImmutableMap.Builder()
    			.put(Constants.GPU_ID,"CUDA_VISIBLE_DEVICES")
    			.put(Constants.MODEL,"--model-id")
    			.put("REVISION","--revision")
    			.put("SHARDED","--sharded")
    			.put(Constants.NUM_GPU,"--num-shard")
    			.put("QUANTIZE","--quantize")
    			.put("TRUST_REMOTE_CODE","--trust-remote-code")
    			.put("MAX_CONCURRENT_REQUESTS","--max-concurrent-requests")
    			.put("MAX_BEST_OF","--max-best-of")
    			.put("MAX_STOP_SEQUENCES","--max-stop-sequences")
    			.put("MAX_INPUT_LENGTH","--max-input-length")
    			.put("MAX_TOTAL_TOKENS","--max-total-tokens")
    			.put("MAX_BATCH_SIZE","--max-batch-size")
    			.put("WAITING_SERVED_RATIO","--waiting-served-ratio")
    			.put("MAX_BATCH_TOTAL_TOKENS","--max-batch-total-tokens")
    			.put("MAX_WAITING_TOKENS","--max-waiting-tokens")
    			.put("PORT","--port")
    			.put("SHARD_UDS_PATH","--shard-uds-path")
    			.put("MASTER_ADDR","--master-addr")
    			.put("MASTER_PORT","--master-port")
    			.put("HUGGINGFACE_HUB_CACHE","--huggingface-hub-cache")
    			.put("WEIGHTS_CACHE_OVERRIDE","--weights-cache-override")
    			.put("DISABLE_CUSTOM_KERNERLS","--disable-custom-kernels")
    			.put("JSON_OUTPUT","--json-output")
    			.put("OTLP_ENDPOINT","--otlp-endpoint")
    			.put("CORS_ALLOW_ORIGIN","--cors-allow-origin")
    			.put("WATERMARK_GAMMA","--watermark-gamma")
    			.put("WATERMARK_DELTA","--watermark-delta")
    			.build();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy