
// prerna.engine.impl.model.TextEmbeddingsEngine (Maven / Gradle / Ivy)
// The newest version!
package prerna.engine.impl.model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.http.entity.ContentType;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import prerna.engine.api.ModelTypeEnum;
import prerna.engine.impl.model.responses.AskModelEngineResponse;
import prerna.engine.impl.model.responses.AskStringModelEngineResponse;
import prerna.engine.impl.model.responses.InstructModelEngineResponse;
import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse;
import prerna.om.Insight;
import prerna.security.HttpHelperUtility;
import prerna.util.Constants;
public class TextEmbeddingsEngine extends AbstractRESTModelEngine {
private static final Logger classLogger = LogManager.getLogger(TextEmbeddingsEngine.class);
private static final String ENDPOINT = "ENDPOINT";
private static final String BATCH_SIZE = "BATCH_SIZE";
private int batchSize;
private String endpoint;
@Override
public void open(Properties smssProp) throws Exception {
super.open(smssProp);
this.endpoint = this.smssProp.getProperty(ENDPOINT);
if(this.endpoint == null || (this.endpoint=this.endpoint.trim()).isEmpty()) {
throw new IllegalArgumentException("This model requires a valid value for " + ENDPOINT);
}
// Utility.checkIfValidDomain(this.endpoint);
this.batchSize = 32;
String batchSizeStr = null;
try {
batchSizeStr = this.smssProp.getProperty(BATCH_SIZE);
if(batchSizeStr != null && !(batchSizeStr=batchSizeStr.trim()).isEmpty()) {
this.batchSize = Integer.valueOf(batchSizeStr);
}
} catch(NumberFormatException e) {
classLogger.error(Constants.STACKTRACE, e);
throw new IllegalArgumentException("The SMSS has an invalid value for " + BATCH_SIZE +". Must be an integer but found " + batchSizeStr);
}
}
@Override
public EmbeddingsModelEngineResponse embeddingsCall(List stringsToEncode, Insight insight, Map parameters) {
List> embeddings = new ArrayList<>();
List> sentenceSublists = new ArrayList<>();
for (int i = 0; i < stringsToEncode.size(); i += batchSize) {
int endIndex = Math.min(i + batchSize, stringsToEncode.size());
List sublist = stringsToEncode.subList(i, endIndex);
sentenceSublists.add(sublist);
}
for(List sublist : sentenceSublists) {
Map bodyMap = new HashMap<>();
bodyMap.put("inputs", sublist);
bodyMap.put("truncate", true);
String output = HttpHelperUtility.postRequestStringBody(this.endpoint, null, new Gson().toJson(bodyMap), ContentType.APPLICATION_JSON, null, null, null);
List> outputParsed = new Gson().fromJson(output, new TypeToken>>() {}.getType());
embeddings.addAll(outputParsed);
}
EmbeddingsModelEngineResponse embeddingsResponse = new EmbeddingsModelEngineResponse(embeddings, 0, 0);
return embeddingsResponse;
}
@Override
public EmbeddingsModelEngineResponse imageEmbeddingsCall(List imagesToEmbed, Insight insight, Map parameters) {
// TODO Auto-generated method stub
return null;
}
@Override
protected AskModelEngineResponse askCall(String question, Object fullPrompt, String context, Insight insight, Map parameters) {
return new AskStringModelEngineResponse("This model does not support text generation.", 0, 0);
}
@Override
public ModelTypeEnum getModelType() {
return ModelTypeEnum.TEXT_EMBEDDINGS;
}
@Override
protected void resetAfterTimeout() {
// nothing to reset currently
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy