All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.moderne.ai.AgentGenerativeModelClient Maven / Gradle / Ivy

There is a newer version: 0.19.1
Show newest version
/*
 * Copyright 2021 the original author or authors.
 * 

* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at *

* https://www.apache.org/licenses/LICENSE-2.0 *

* Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.moderne.ai; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.cfg.ConstructorDetector; import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; import kong.unirest.HttpResponse; import kong.unirest.Unirest; import kong.unirest.UnirestException; import lombok.Value; import org.openrewrite.internal.lang.Nullable; import org.openrewrite.ipc.http.HttpSender; import org.openrewrite.ipc.http.HttpUrlConnectionSender; import java.io.*; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.regex.Matcher; import java.util.regex.Pattern; public class AgentGenerativeModelClient { @Nullable private static AgentGenerativeModelClient INSTANCE; private static HashMap methodsToSample; private final ObjectMapper mapper = JsonMapper.builder() .constructorDetector(ConstructorDetector.USE_PROPERTIES_BASED) .build() .registerModule(new ParameterNamesModule()) .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); private static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool(3); static String pathToModel = "/MODELS/codellama.gguf"; static String pathToLLama = "/app/llama.cpp"; static String pathToFiles = "/app/"; static String port = "7878"; public static synchronized AgentGenerativeModelClient getInstance() { if (INSTANCE == null) { //Check if llama.cpp is already built File f = new File(pathToLLama + "/main"); if (!(f.exists() && !f.isDirectory())) { //Build llama.cpp StringWriter sw = new StringWriter(); PrintWriter procOut = new PrintWriter(sw); try { Runtime runtime = Runtime.getRuntime(); Process proc_make = runtime.exec(new String[]{"/bin/sh", "-c", "make -C " + pathToLLama}); proc_make.waitFor(); new BufferedReader(new InputStreamReader(proc_make.getInputStream())).lines() .forEach(procOut::println); new BufferedReader(new InputStreamReader(proc_make.getErrorStream())).lines() .forEach(procOut::println); if (proc_make.exitValue() != 0) { throw new RuntimeException("Failed to make llama.cpp at " + pathToLLama + "\n" + sw); } } catch (IOException | InterruptedException e) { throw new RuntimeException(e + "\nOutput: " + sw); } } INSTANCE = new AgentGenerativeModelClient(); //Start server if (INSTANCE.checkForUpRequest() != 200) { StringWriter sw = new StringWriter(); PrintWriter procOut = new PrintWriter(sw); try { Runtime runtime = Runtime.getRuntime(); Process proc_server = runtime.exec((new String[] {"/bin/sh", "-c", pathToLLama + "/server -m " + pathToModel + " --port " + port})); EXECUTOR_SERVICE.submit(() -> { new BufferedReader(new InputStreamReader(proc_server.getInputStream())).lines() .forEach(procOut::println); new BufferedReader(new InputStreamReader(proc_server.getErrorStream())).lines() .forEach(procOut::println); }); if (!INSTANCE.checkForUp()) { throw new RuntimeException("Failed to start server\n" + sw); } } catch (IOException e) { throw new RuntimeException(e + "\nOutput: " + sw); } } return INSTANCE; } return INSTANCE; } private int checkForUpRequest() { try { HttpResponse response = Unirest.head("http://127.0.0.1:" + port).asString(); return response.getStatus(); } catch (UnirestException e) { return 523; } } private boolean checkForUp() { for (int i = 0; i < 60; i++) { try { if (checkForUpRequest() == 200) { return true; } Thread.sleep(1_000); } catch (InterruptedException e) { throw new RuntimeException(e); } } return false; } public static void populateMethodsToSample(String pathToCenters) { HashMap tempMethodsToSample = new HashMap<>(); try (BufferedReader bufferedReader = new BufferedReader(new FileReader(pathToCenters))) { String line; String source; String methodCall; while ((line = bufferedReader.readLine()) != null) { source = line.split(" ")[0]; methodCall = line.split(" ")[1]; tempMethodsToSample.put(source, methodCall); } } catch (IOException e) { throw new RuntimeException("Couldn't read which methods to sample. " + e); } methodsToSample = tempMethodsToSample; } public static HashMap getMethodsToSample() { return methodsToSample; } public ArrayList getRecommendations(String code) { try ( BufferedReader bufferedReader = new BufferedReader(new FileReader(pathToFiles + "prompt.txt")) ) { // Write a temporary file for input which includes prompt and relevant code snippet String line; StringBuilder promptContent = new StringBuilder(); while ((line = bufferedReader.readLine()) != null) { promptContent.append(line).append("\n"); } String text = "[INST]" + promptContent + code + "```\n[/INST]1."; HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(90)); HttpSender.Response raw; HashMap input = new HashMap<>(); input.put("stream", false); input.put("prompt", text); input.put("temperature", 0.5); input.put("n_predict", 150); try { raw = http .post("http://127.0.0.1:" + port + "/completion") .withContent("application/json", mapper.writeValueAsBytes(input)).send(); } catch (JsonProcessingException e) { throw new RuntimeException(e); } if (!raw.isSuccessful()) { throw new IllegalStateException("Unable to get recommendations. HTTP " + raw.getClass()); } String textResponse; textResponse = mapper.readValue(raw.getBodyAsBytes(), LlamaResponse.class).getResponse(); ArrayList recommendations = parseRecommendations("1." + textResponse); if (recommendations.isEmpty()) { BufferedReader bufferedReaderLog = new BufferedReader(new FileReader(pathToFiles + "llama_log.txt")); String logLine; StringBuilder logContent = new StringBuilder(); while ((logLine = bufferedReaderLog.readLine()) != null) { logContent.append(logLine).append("\n"); } bufferedReaderLog.close(); throw new RuntimeException("Logs: " + logContent); } return recommendations; } catch (IOException e) { throw new RuntimeException(e); } } public ArrayList parseRecommendations(String recommendations) { if (recommendations.equals("[]")) { return new ArrayList<>(); } else { String patternString = "\\b\\d+[.:\\-]\\s+(.*?)\\s*(?=\\b\\d+[.:\\-]|\\Z)"; Pattern pattern = Pattern.compile(patternString, Pattern.DOTALL); Matcher matcher = pattern.matcher(recommendations); ArrayList matches = new ArrayList<>(); while (matcher.find()) { matches.add(matcher.group(1)); } return matches; } } public boolean isRelated(String query, String code, double threshold) { String promptContent = "Does this query match the code snippet?\n"; promptContent += "Query: " + query + "\n"; promptContent += "Code: " + code + "\n"; promptContent += "Answer as 'ANS: Yes' or 'ANS: No'.\n"; promptContent = "[INST]" + promptContent + "[/INST]ANS:"; HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(60)); HttpSender.Response raw; HashMap input = new HashMap<>(); input.put("stream", false); input.put("prompt", promptContent); input.put("temperature", 0.0); input.put("n_predict", 3); input.put("n_probs", 10); try { raw = http .post("http://127.0.0.1:" + port + "/completion") .withContent("application/json", mapper.writeValueAsBytes(input)).send(); } catch (JsonProcessingException e) { throw new RuntimeException(e); } if (!raw.isSuccessful()) { throw new IllegalStateException("Unable to get response from server. HTTP " + raw.getClass()); } boolean relatedResponse; try { relatedResponse = mapper.readValue(raw.getBodyAsBytes(), LlamaResponseProbabilities.class).isRelated(threshold); } catch (IOException e) { throw new RuntimeException(e); } return relatedResponse; } private boolean parseRelated(String s) { // check if Yes or yes in output, return true if it does return (s.contains("Yes") || s.contains("yes")); } @Value private static class LlamaResponse { String content; public String getResponse() { return content; } } @Value public static class LlamaResponseProbabilities { List completionProbabilities; String content; boolean multimodal; int slotId; boolean stop; @JsonProperty("completion_probabilities") public List getCompletionProbabilities() { return completionProbabilities; } public boolean isRelated(double threshold) { for (CompletionProbability cp : completionProbabilities) { if (cp.getContent().equals(" Yes")) { return cp.getProbs().get(0).getProb() >= threshold; } } return false; } } @Value public static class CompletionProbability { String content; List probs; public String getContent() { return content; } public List getProbs() { return probs; } } @Value public static class TokenProbability { double prob; String tokStr; public double getProb() { return prob; } public String getTokStr() { return tokStr; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy