All downloads are free. Search and download functionality uses the official Maven repository.

prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse Maven / Gradle / Ivy

The newest version!
package prerna.engine.impl.model.responses;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

public class EmbeddingsModelEngineResponse extends AbstractModelEngineResponse>> {
	
	/**
	 * 
	 */
	private static final long serialVersionUID = 4408956306133085964L;

	public EmbeddingsModelEngineResponse(List> response, Integer numberOfTokensInPrompt, Integer numberOfTokensInResponse) {
        super(response, numberOfTokensInPrompt, numberOfTokensInResponse);
    }

	@SuppressWarnings("unchecked")
	public static EmbeddingsModelEngineResponse fromMap(Map modelResponse) {
		Object responseObj = modelResponse.get(RESPONSE);
		if(responseObj instanceof String) {
			throw new IllegalArgumentException((String) responseObj);
		}
		List> responseObject = (List>) modelResponse.get(RESPONSE);
        Integer tokensInPrompt = getTokens(modelResponse.get(NUMBER_OF_TOKENS_IN_PROMPT));
        Integer tokensInResponse = getTokens(modelResponse.get(NUMBER_OF_TOKENS_IN_RESPONSE));
        
        return new EmbeddingsModelEngineResponse(responseObject, tokensInPrompt, tokensInResponse);
    }
	
	@SuppressWarnings("unchecked")
	public static EmbeddingsModelEngineResponse fromObject(Object responseObject) {
		Map modelResponse = (Map) responseObject;
		return fromMap(modelResponse);
    }
	
	public static EmbeddingsModelEngineResponse fromJson(JSONObject jsonResponse) {
	    if (jsonResponse == null) {
	        return null;
	    }
	    
	    List> embeddings = null;
	    Integer promptTokens = 0;
	    Integer responseTokens = 0;
	    
	    if (jsonResponse.has("output")) {
	        Object outputObj = jsonResponse.get("output");
	        
	        // Handle the case where output is a JSONArray containing a single JSON string
	        if (outputObj instanceof org.json.JSONArray) {
	            org.json.JSONArray outputArray = (org.json.JSONArray) outputObj;
	            if (outputArray.length() > 0) {
	                String jsonString = outputArray.getString(0);
	                try {
	                    JSONObject embeddingsJson = new JSONObject(jsonString);
	                    
	                    if (embeddingsJson.has("embeddings")) {
	                        org.json.JSONArray embeddingsArray = embeddingsJson.getJSONArray("embeddings");
	                        embeddings = new java.util.ArrayList<>();
	                        
	                        for (int i = 0; i < embeddingsArray.length(); i++) {
	                            org.json.JSONArray vectorArray = embeddingsArray.getJSONArray(i);
	                            List vector = new java.util.ArrayList<>();
	                            
	                            for (int j = 0; j < vectorArray.length(); j++) {
	                                vector.add(vectorArray.getDouble(j));
	                            }
	                            
	                            embeddings.add(vector);
	                        }
	                    }
	                } catch (Exception e) {
	                    throw new IllegalArgumentException("Failed to parse embeddings JSON: " + e.getMessage());
	                }
	            }
	        } else if (outputObj instanceof JSONObject) {
	            JSONObject embeddingsJson = (JSONObject) outputObj;
	            
	            if (embeddingsJson.has("embeddings")) {
	                org.json.JSONArray embeddingsArray = embeddingsJson.getJSONArray("embeddings");
	                embeddings = new java.util.ArrayList<>();
	                
	                for (int i = 0; i < embeddingsArray.length(); i++) {
	                    org.json.JSONArray vectorArray = embeddingsArray.getJSONArray(i);
	                    List vector = new java.util.ArrayList<>();
	                    
	                    for (int j = 0; j < vectorArray.length(); j++) {
	                        vector.add(vectorArray.getDouble(j));
	                    }
	                    
	                    embeddings.add(vector);
	                }
	            }
	        } else if (outputObj instanceof String) {
	            String jsonString = (String) outputObj;
	            try {
	                JSONObject embeddingsJson = new JSONObject(jsonString);
	                
	                if (embeddingsJson.has("embeddings")) {
	                    org.json.JSONArray embeddingsArray = embeddingsJson.getJSONArray("embeddings");
	                    embeddings = new java.util.ArrayList<>();
	                    
	                    for (int i = 0; i < embeddingsArray.length(); i++) {
	                        org.json.JSONArray vectorArray = embeddingsArray.getJSONArray(i);
	                        List vector = new java.util.ArrayList<>();
	                        
	                        for (int j = 0; j < vectorArray.length(); j++) {
	                            vector.add(vectorArray.getDouble(j));
	                        }
	                        
	                        embeddings.add(vector);
	                    }
	                }
	            } catch (Exception e) {
	                throw new IllegalArgumentException("Failed to parse embeddings JSON: " + e.getMessage());
	            }
	        }
	    }
	    
	    if (embeddings == null) {
	        embeddings = new java.util.ArrayList<>();
	    }
	    
	    if (jsonResponse.has("input_tokens")) {
	        promptTokens = jsonResponse.getInt("input_tokens");
	    }
	    
	    if (jsonResponse.has("output_tokens")) {
	        responseTokens = jsonResponse.getInt("output_tokens");
	    }
	    
	    return new EmbeddingsModelEngineResponse(embeddings, promptTokens, responseTokens);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy