All downloads are free. Search and download functionality uses the official Maven repository.

prerna.reactor.model.NERReactor Maven / Gradle / Ivy

The newest version!
package prerna.reactor.model;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import prerna.auth.utils.SecurityEngineUtils;
import prerna.engine.api.IModelEngine;
import prerna.engine.impl.model.NEREngine;
import prerna.reactor.AbstractReactor;
import prerna.sablecc2.om.GenRowStruct;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.PixelOperationType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.Utility;
import prerna.engine.impl.model.responses.NerModelEngineResponse;

/**
 * Pixel reactor that runs named entity recognition (NER) on a prompt using a model engine.
 * <p>
 * Inputs, in declaration order:
 * <ol>
 * <li>{@link ReactorKeysEnum#ENGINE} (required) - id of the model engine; it must be an {@link NEREngine}</li>
 * <li>{@link ReactorKeysEnum#PROMPT} (required) - URI-encoded text to analyze</li>
 * <li>{@link ReactorKeysEnum#ENTITIES} (required) - entity labels to detect</li>
 * <li>{@link ReactorKeysEnum#MASK_ENTITIES} (optional) - entity labels to mask</li>
 * <li>{@link ReactorKeysEnum#PARAM_VALUES_MAP} (optional) - extra parameters passed to the engine</li>
 * </ol>
 * The engine's {@link NerModelEngineResponse} is returned wrapped as a MAP noun.
 */
public class NERReactor extends AbstractReactor {
	
	private static final Logger classLogger = LogManager.getLogger(NERReactor.class);
	
	/**
	 * Declares the reactor keys. The engine, prompt and entities are required;
	 * the mask entities and the parameter map are optional.
	 */
	public NERReactor() {
		this.keysToGet = new String[] {
				ReactorKeysEnum.ENGINE.getKey(),
				ReactorKeysEnum.PROMPT.getKey(),
				ReactorKeysEnum.ENTITIES.getKey(),
				ReactorKeysEnum.MASK_ENTITIES.getKey(),
				ReactorKeysEnum.PARAM_VALUES_MAP.getKey()
		};
		this.keyRequired = new int[] {1, 1, 1, 0, 0};
	}
	
	/**
	 * Validates access to the requested engine, gathers the inputs and runs NER prediction.
	 *
	 * @return the engine response wrapped as a MAP noun
	 * @throws IllegalArgumentException if the user cannot view the engine, or the engine
	 *         cannot be loaded or is not an NER engine
	 */
	@Override
	public NounMetadata execute() {
		organizeKeys();
		String engineId = this.keyValue.get(this.keysToGet[0]);
		
		if(!SecurityEngineUtils.userCanViewEngine(this.insight.getUser(), engineId)) {
			throw new IllegalArgumentException("Model " + engineId + " does not exist or user does not have access to this model");
		}
		
		String prompt = Utility.decodeURIComponent(this.keyValue.get(this.keysToGet[1]));
		// read the list inputs through the declared keys so they stay in sync with keysToGet
		List<String> entities = getListInput(this.keysToGet[2]);
		List<String> maskEntities = getListInput(this.keysToGet[3]);

		Map<String, Object> paramMap = getMap();
		if(paramMap == null) {
			// the engine always receives a (possibly empty) parameter map
			paramMap = new HashMap<>();
		}
		
		NEREngine targetEngine = getNerEngine(engineId);
		NerModelEngineResponse output = targetEngine.predict(prompt, entities, maskEntities, this.insight, paramMap);
		
		return new NounMetadata(output, PixelDataType.MAP, PixelOperationType.OPERATION);
	}

	/**
	 * Loads the model engine and narrows it to {@link NEREngine}.
	 * NER prediction is not exposed on {@link IModelEngine}, so a checked downcast is required.
	 *
	 * @param engineId id of the model engine to load
	 * @return the engine as an {@link NEREngine}
	 * @throws IllegalArgumentException if no engine is returned or it is not an NER engine
	 */
	private NEREngine getNerEngine(String engineId) {
		IModelEngine targetModel = Utility.getModel(engineId);
		if (targetModel == null) {
			throw new IllegalArgumentException("Model " + engineId + " could not be loaded");
		}
		if (!(targetModel instanceof NEREngine)) {
			throw new IllegalArgumentException("Model " + engineId + " is not a named entity recognition (NER) model");
		}
		return (NEREngine) targetModel;
	}

	/**
	 * Finds the optional parameter map.
	 * The named parameter-map noun is checked first. If it holds no MAP value, the first
	 * MAP noun on the current row is used instead.
	 *
	 * @return the parameter map, or {@code null} if none was supplied
	 */
	@SuppressWarnings("unchecked") // MAP nouns carry their value as an untyped Object
	private Map<String, Object> getMap() {
		GenRowStruct mapGrs = this.store.getNoun(this.keysToGet[4]);
		if(mapGrs != null && !mapGrs.isEmpty()) {
			List<NounMetadata> mapInputs = mapGrs.getNounsOfType(PixelDataType.MAP);
			if(mapInputs != null && !mapInputs.isEmpty()) {
				return (Map<String, Object>) mapInputs.get(0).getValue();
			}
		}
		List<NounMetadata> mapInputs = this.curRow.getNounsOfType(PixelDataType.MAP);
		if(mapInputs != null && !mapInputs.isEmpty()) {
			return (Map<String, Object>) mapInputs.get(0).getValue();
		}
		return null;
	}
	
	/**
	 * Collects every value stored under the given noun as a string.
	 *
	 * @param noun the noun name to read from the store
	 * @return the values in order, or an empty list if the noun was not supplied
	 */
	private List<String> getListInput(String noun) {
		List<String> colInputs = new Vector<>();
		GenRowStruct colGRS = this.store.getNoun(noun);
		if (colGRS != null) {
			for (int i = 0; i < colGRS.size(); i++) {
				colInputs.add(colGRS.get(i).toString());
			}
		}
		return colInputs;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy