
// prerna.engine.impl.model.EmbedderKeywordExtractionReactor (Maven / Gradle / Ivy) - the newest version
package prerna.engine.impl.model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import prerna.auth.User;
import prerna.auth.utils.SecurityEngineUtils;
import prerna.engine.api.IModelEngine;
import prerna.reactor.AbstractReactor;
import prerna.sablecc2.om.GenRowStruct;
import prerna.sablecc2.om.PixelDataType;
import prerna.sablecc2.om.ReactorKeysEnum;
import prerna.sablecc2.om.nounmeta.NounMetadata;
import prerna.util.Utility;
/**
 * Reactor that extracts keywords from one or more text inputs by delegating to
 * {@link EmbeddedModelEngine#keywordExtraction}.
 *
 * <p>Keys, in order: the model id (required), the input strings (required),
 * an optional {@code percentile} cutoff and an optional limit that is passed
 * to the engine as {@code max_keywords}.
 */
public class EmbedderKeywordExtractionReactor extends AbstractReactor {

	private static final Logger classLogger = LogManager.getLogger(EmbedderKeywordExtractionReactor.class);
	private static final String PERCENTILE = "percentile";

	/**
	 * Registers the accepted keys; only the model and the input are required.
	 */
	public EmbedderKeywordExtractionReactor() {
		this.keysToGet = new String[] {ReactorKeysEnum.MODEL.getKey(), ReactorKeysEnum.INPUT.getKey(),
				PERCENTILE, ReactorKeysEnum.LIMIT.getKey()};
		this.keyRequired = new int[] {1, 1, 0, 0};
	}

	/**
	 * Validates access to the model, collects the optional parameters, decodes
	 * every input string and runs the keyword extraction.
	 *
	 * @return the engine's keyword list wrapped as a {@code PixelDataType.VECTOR} noun
	 * @throws IllegalArgumentException if the user cannot view the model, the model is
	 *         not an {@link EmbeddedModelEngine}, a numeric parameter is malformed or
	 *         out of range, or no input was supplied
	 */
	@Override
	public NounMetadata execute() {
		organizeKeys();
		String engineId = this.keyValue.get(this.keysToGet[0]);
		User user = this.insight.getUser();
		if (!SecurityEngineUtils.userCanViewEngine(user, engineId)) {
			throw new IllegalArgumentException(
					"Model " + engineId + " does not exist or user does not have access to this model");
		}

		IModelEngine modelE = Utility.getModel(engineId);
		if (!(modelE instanceof EmbeddedModelEngine)) {
			throw new IllegalArgumentException("This method only works for Local EmbeddedModelEngines");
		}
		EmbeddedModelEngine eme = (EmbeddedModelEngine) modelE;

		// only forward the optional parameters that were actually supplied
		Map<String, Object> parameters = new HashMap<>();
		Integer percentile = parseOptionalInt(PERCENTILE, this.keyValue.get(PERCENTILE));
		if (percentile != null) {
			// enforce the range advertised in the key description
			if (percentile < 0 || percentile > 100) {
				throw new IllegalArgumentException(
						"Value for '" + PERCENTILE + "' must be between 0 and 100 inclusive but received " + percentile);
			}
			parameters.put("percentile", percentile);
		}
		String limitKey = this.keysToGet[3];
		Integer limit = parseOptionalInt(limitKey, this.keyValue.get(limitKey));
		if (limit != null) {
			parameters.put("max_keywords", limit);
		}

		List<String> input = getInput();
		if (input.isEmpty()) {
			throw new IllegalArgumentException("Must pass in list of inputs");
		}

		List<String> decoded = new ArrayList<>(input.size());
		for (String encoded : input) {
			decoded.add(Utility.decodeURIComponent(encoded));
		}

		List<?> keywords = eme.keywordExtraction(decoded, insight, parameters);
		return new NounMetadata(keywords, PixelDataType.VECTOR);
	}

	/**
	 * Parses an optional numeric parameter. Decimal values are accepted and
	 * truncated toward zero.
	 *
	 * @param key the parameter name, used only in error messages
	 * @param raw the raw value; may be {@code null}
	 * @return the integer value, or {@code null} when the value is absent or blank
	 * @throws IllegalArgumentException if the value is not a finite number
	 */
	private static Integer parseOptionalInt(String key, String raw) {
		if (raw == null) {
			return null;
		}
		String trimmed = raw.trim();
		if (trimmed.isEmpty()) {
			return null;
		}
		final double parsed;
		try {
			parsed = Double.parseDouble(trimmed);
		} catch (NumberFormatException e) {
			throw new IllegalArgumentException(
					"Value for '" + key + "' must be a number but received '" + trimmed + "'", e);
		}
		// parseDouble accepts "NaN" and "Infinity", which would silently become 0 / Integer.MAX_VALUE
		if (!Double.isFinite(parsed)) {
			throw new IllegalArgumentException(
					"Value for '" + key + "' must be a finite number but received '" + trimmed + "'");
		}
		return (int) parsed;
	}

	/**
	 * Collects the input strings, preferring the values stored under the input
	 * key and falling back to the current row when that key has no values.
	 *
	 * @return the inputs as strings; empty when nothing was passed
	 */
	private List<String> getInput() {
		List<String> inputs = new ArrayList<>();
		GenRowStruct keyedGrs = this.store.getNoun(this.keysToGet[1]);
		if (keyedGrs != null && !keyedGrs.isEmpty()) {
			appendAsStrings(keyedGrs, inputs);
		} else {
			GenRowStruct currentRow = this.getCurRow();
			if (currentRow != null && !currentRow.isEmpty()) {
				appendAsStrings(currentRow, inputs);
			}
		}
		return inputs;
	}

	/**
	 * Appends the string form of every value in {@code grs} to {@code target}.
	 */
	private static void appendAsStrings(GenRowStruct grs, List<String> target) {
		for (int index = 0; index < grs.size(); index++) {
			target.add(grs.get(index) + "");
		}
	}

	@Override
	public String getReactorDescription() {
		return "Utilizes a keyBERT model to extract the keywords from the text input";
	}

	@Override
	protected String getDescriptionForKey(String key) {
		if (key.equals(ReactorKeysEnum.INPUT.getKey())) {
			return "The input array of string values to extract keywords from. Each string input will result in a space delimited list of keywords. "
					+ "Each element in input should be URI component encoded for special character escaping";
		} else if (key.equals(PERCENTILE)) {
			return "The percentile (integer) cutoff for the words within the text to be considered a keyword. Values must be between 0 and 100 inclusive.";
		} else if (key.equals(ReactorKeysEnum.LIMIT.getKey())) {
			return "The limit to be applied after the percentile for the maximum number of keywords to be returned for each string input";
		}
		return super.getDescriptionForKey(key);
	}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy