All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.moderne.ai.research.FindCodeThatResembles Maven / Gradle / Ivy

There is a newer version: 0.21.0
Show newest version
/*
 * Copyright 2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.moderne.ai.research;

import io.moderne.ai.AgentGenerativeModelClient;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.RelatedModelClient;
import io.moderne.ai.table.CodeSearch;
import io.moderne.ai.table.EmbeddingPerformance;
import io.moderne.ai.table.GenerativeModelPerformance;
import io.moderne.ai.table.SuggestedMethodPatterns;
import io.moderne.ai.table.TopKMethodMatcher;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.experimental.NonFinal;
import org.jspecify.annotations.Nullable;
import org.openrewrite.*;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.marker.SearchResult;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Objects.requireNonNull;

/**
 * Searches for method invocations that resemble a free-text or code-sample query.
 *
 * <p>The scanning phase collects every method used by the visited compilation units and ranks
 * them by embedding distance to {@link #resembles}. The visiting phase then only inspects
 * invocations of the {@link #k} closest methods, asking {@link RelatedModelClient} (and, when that
 * verdict is {@code 0}, {@link AgentGenerativeModelClient}) whether the invocation is related.
 * Results and model timings are written to the data tables declared on this recipe.
 */
@Value
@EqualsAndHashCode(callSuper = false)
public class FindCodeThatResembles extends ScanningRecipe<FindCodeThatResembles.Accumulator> {

    @Option(displayName = "Resembles",
            description = "The text, either a natural language description or a code sample, " +
                          "that you are looking for.",
            example = "HTTP request with Content-Type application/json")
    String resembles;

    @Option(displayName = "Top-K methods",
            description = "Since AI based matching has a higher latency than rules based matching, " +
                          "we do a first pass to find the top k methods using embeddings. " +
                          "To narrow the scope, you can specify the top k methods as method filters.",
            example = "5")
    int k;

    transient CodeSearch codeSearchTable = new CodeSearch(this);
    transient TopKMethodMatcher topKTable = new TopKMethodMatcher(this);
    transient EmbeddingPerformance embeddingPerformance = new EmbeddingPerformance(this);
    transient GenerativeModelPerformance generativeModelPerformance = new GenerativeModelPerformance(this);
    transient SuggestedMethodPatterns suggestedMethodPatternsTable = new SuggestedMethodPatterns(this);

    @Override
    public String getDisplayName() {
        return "Find method invocations that resemble a pattern";
    }

    @Override
    public String getDescription() {
        return "This recipe uses two phase AI approach to find a method invocation" +
               " that resembles a search string.";
    }

    /**
     * Raises {@code max} to the length of {@code timing} in nanoseconds when that is larger.
     */
    private static void recordMax(AtomicLong max, Duration timing) {
        long nanos = timing.toNanos();
        if (max.get() < nanos) {
            max.set(nanos);
        }
    }

    /**
     * A method seen during scanning, together with its embedding distance to the query.
     */
    @Value
    private static class MethodSignatureWithDistance {
        String methodSignature;
        String methodPattern;
        double distance;
    }

    /**
     * Scan-phase state: the distinct method patterns seen so far, ordered by ascending distance
     * to the query, plus the top-k selection once {@link #populateTopK()} has run.
     */
    @Value
    @RequiredArgsConstructor
    public static class Accumulator {
        /** Whether the top-k data table rows have already been written. */
        @NonFinal
        @Nullable
        Boolean populatedTopKDataTable = false;

        final int k;

        /** Smallest distance first, so polling yields the closest candidates. */
        PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue =
                new PriorityQueue<>(Comparator.comparingDouble(MethodSignatureWithDistance::getDistance));

        EmbeddingModelClient embeddingModelClient = EmbeddingModelClient.getInstance();

        /** Method patterns already queued; prevents computing a distance twice for the same pattern. */
        private HashSet<String> methodPatternsSet = new HashSet<>();

        @NonFinal
        @Nullable
        List<MethodMatcher> topMethodPatterns;

        @NonFinal
        @Nullable
        List<MethodSignatureWithDistance> topMethodSignatureWithDistances;

        /**
         * Queues a method for ranking unless its pattern was queued before.
         *
         * @param methodSignature human-readable signature that is embedded and compared to the query
         * @param methodPattern   method pattern identifying the method; used for de-duplication
         * @param resembles       the query the signature is compared against
         */
        public void add(String methodSignature, String methodPattern, String resembles) {
            if (methodPatternsSet.contains(methodPattern)) {
                return;
            }
            MethodSignatureWithDistance methodSignatureWithDistance = new MethodSignatureWithDistance(
                    methodSignature,
                    methodPattern,
                    (float) embeddingModelClient.getDistance(resembles, methodSignature)
            );
            methodSignaturesQueue.add(methodSignatureWithDistance);
            methodPatternsSet.add(methodPattern);
        }

        /**
         * @return the selected candidates, or {@code null} before {@link #populateTopK()} has run
         */
        public List<MethodSignatureWithDistance> getTopMethodSignatureWithDistances() {
            return topMethodSignatureWithDistances;
        }

        /**
         * @return matchers for the selected candidates, or {@code null} before
         * {@link #populateTopK()} has run
         */
        public List<MethodMatcher> getTopMethodPatterns() {
            return topMethodPatterns;
        }

        /**
         * Drains up to {@code k} of the closest candidates from the queue and builds a
         * {@link MethodMatcher} for each.
         *
         * @return the freshly built matchers, or {@code null} when the selection had already been
         * made by an earlier call (use {@link #getTopMethodPatterns()} to read it)
         */
        public @Nullable List<MethodMatcher> populateTopK() {
            if (topMethodPatterns != null) {
                return null;
            }
            topMethodPatterns = new ArrayList<>(k);
            topMethodSignatureWithDistances = new ArrayList<>(k);
            for (int i = 0; i < k && !methodSignaturesQueue.isEmpty(); i++) {
                MethodSignatureWithDistance currentMethod = methodSignaturesQueue.poll();
                String inputString = currentMethod.getMethodPattern();
                // Drop angle-bracketed generic arguments from the pattern. Patterns naming a
                // constructor are left untouched because the same regex would also delete the
                // "<constructor>" token itself.
                if (!inputString.contains("<constructor>")) {
                    inputString = inputString.replaceAll("<[^>]*>", "");
                }
                topMethodPatterns.add(new MethodMatcher(inputString, true));
                topMethodSignatureWithDistances.add(currentMethod);
            }
            return topMethodPatterns;
        }

        public void setPopulatedTopKDataTable(boolean b) {
            this.populatedTopKDataTable = b;
        }
    }

    @Override
    public Accumulator getInitialValue(ExecutionContext ctx) {
        return new Accumulator(k);
    }

    /**
     * Builds the scan-phase visitor, which feeds every method used by a compilation unit into the
     * accumulator as a (signature, pattern) pair.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
        return new JavaIsoVisitor<ExecutionContext>() {

            /**
             * Shortens a fully qualified type name to its simple name, keeping at most the first
             * angle-bracketed segment (also shortened), e.g. {@code java.util.List<java.lang.String>}
             * becomes {@code List<String>}.
             */
            private String extractTypeName(String fullyQualifiedTypeName) {
                // Split on '<' and '>'; the brackets themselves are dropped and re-added below.
                String[] parts = fullyQualifiedTypeName.split("(<|>)");
                if (parts.length == 0) {
                    // Input consisted solely of angle brackets; there is nothing to shorten.
                    return fullyQualifiedTypeName;
                }
                String outer = parts[0];
                String inner = parts.length > 1 ? parts[1] : "";
                outer = outer.substring(outer.lastIndexOf('.') + 1);
                inner = inner.substring(inner.lastIndexOf('.') + 1);
                return inner.isEmpty() ? outer : outer + "<" + inner + ">";
            }

            @SuppressWarnings("OptionalOfNullableMisuse")
            @Override
            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                cu.getTypesInUse().getUsedMethods().forEach(type -> {
                    String returnTypeName = extractTypeName(Optional.ofNullable(type.getReturnType())
                            .map(Object::toString).orElse(""));

                    // "Type name, Type name, ..." for every declared parameter.
                    StringJoiner parameters = new StringJoiner(", ");
                    for (int i = 0; i < type.getParameterTypes().size(); i++) {
                        String typeName = extractTypeName(type.getParameterTypes().get(i).toString());
                        String paramName = type.getParameterNames().get(i);
                        parameters.add(typeName + " " + paramName);
                    }
                    String methodSignature = returnTypeName + " " + type.getName() + "(" + parameters + ")";

                    String declaringTypeName = Optional.ofNullable(type.getDeclaringType())
                            .map(Object::toString).orElse("");
                    String methodPattern = declaringTypeName + " " + type.getName() + "(..)";

                    acc.add(methodSignature, methodPattern, resembles);
                });
                return super.visitCompilationUnit(cu, ctx);
            }
        };
    }

    /**
     * Builds the edit-phase visitor. It is restricted, via preconditions, to source files that use
     * at least one of the top-k methods, and marks matching invocations with a search result.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
        acc.populateTopK();
        List<MethodMatcher> methodMatchers = acc.getTopMethodPatterns();
        List<TreeVisitor<?, ExecutionContext>> preconditions = new ArrayList<>(methodMatchers.size());
        for (MethodMatcher m : methodMatchers) {
            preconditions.add(new UsesMethod<>(m));
        }

        //noinspection unchecked
        return Preconditions.check(Preconditions.or(preconditions.toArray(new TreeVisitor[0])), new JavaIsoVisitor<ExecutionContext>() {
            @Override
            public boolean isAcceptable(SourceFile sourceFile, ExecutionContext ctx) {
                return sourceFile instanceof J.CompilationUnit;
            }

            /**
             * Installs per-file timing counters on the cursor, visits the file, then writes one
             * performance row per model family that was actually called.
             */
            @Override
            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                getCursor().putMessage("countEmbedding", new AtomicInteger());
                getCursor().putMessage("maxEmbedding", new AtomicLong());
                getCursor().putMessage("histogramEmbedding", new EmbeddingPerformance.Histogram());
                getCursor().putMessage("countGenerative", new AtomicInteger());
                getCursor().putMessage("maxGenerative", new AtomicLong());
                getCursor().putMessage("histogramGenerative", new GenerativeModelPerformance.Histogram());
                try {
                    return super.visitCompilationUnit(cu, ctx);
                } finally {
                    if (getCursor().getMessage("countEmbedding", new AtomicInteger()).get() > 0) {
                        Duration embeddingMax = Duration.ofNanos(
                                requireNonNull(getCursor().<AtomicLong>getMessage("maxEmbedding")).get());
                        embeddingPerformance.insertRow(ctx, new EmbeddingPerformance.Row(
                                ((SourceFile) cu).getSourcePath().toString(),
                                requireNonNull(getCursor().<AtomicInteger>getMessage("countEmbedding")).get(),
                                requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getMessage("histogramEmbedding")).getBuckets(),
                                embeddingMax));
                    }
                    if (getCursor().getMessage("countGenerative", new AtomicInteger()).get() > 0) {
                        Duration generativeMax = Duration.ofNanos(
                                requireNonNull(getCursor().<AtomicLong>getMessage("maxGenerative")).get());
                        generativeModelPerformance.insertRow(ctx, new GenerativeModelPerformance.Row(
                                ((SourceFile) cu).getSourcePath().toString(),
                                requireNonNull(getCursor().<AtomicInteger>getMessage("countGenerative")).get(),
                                requireNonNull(getCursor().<GenerativeModelPerformance.Histogram>getMessage("histogramGenerative")).getBuckets(),
                                generativeMax));
                    }
                }
            }

            @Override
            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                // The top-k table is written once, on the first invocation visited.
                if (!Boolean.TRUE.equals(acc.populatedTopKDataTable)) {
                    List<MethodSignatureWithDistance> methodMatchersDistance = acc.getTopMethodSignatureWithDistances();
                    for (MethodSignatureWithDistance methodSignatureWithDistance : methodMatchersDistance) {
                        topKTable.insertRow(ctx, new TopKMethodMatcher.Row(
                                methodSignatureWithDistance.getMethodPattern(),
                                methodSignatureWithDistance.getMethodSignature(),
                                methodSignatureWithDistance.getDistance(),
                                resembles
                        ));
                    }
                    acc.setPopulatedTopKDataTable(true);
                }

                // Only invocations of a top-k method are sent to the models.
                boolean matches = false;
                String methodPattern = "";
                for (MethodMatcher m : methodMatchers) {
                    if (m.matches(method)) {
                        matches = true;
                        methodPattern = m.toString();
                        break;
                    }
                }
                if (!matches) {
                    return super.visitMethodInvocation(method, ctx);
                }

                // Printed once and reused for both model calls and both data table rows.
                String printedInvocation = method.printTrimmed(getCursor());

                RelatedModelClient.Relatedness related = RelatedModelClient.getInstance()
                        .getRelatedness(resembles, printedInvocation);
                for (Duration timing : related.getEmbeddingTimings()) {
                    requireNonNull(getCursor().<AtomicInteger>getNearestMessage("countEmbedding")).incrementAndGet();
                    requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getNearestMessage("histogramEmbedding")).add(timing);
                    recordMax(requireNonNull(getCursor().<AtomicLong>getNearestMessage("maxEmbedding")), timing);
                }

                // Verdict of the embedding models: -1, 0 or 1. Here 1 is treated as a match and
                // 0 defers the decision to the generative model.
                int resultEmbeddingModels = related.isRelated();
                boolean calledGenerativeModel = false;
                boolean resultGenerativeModel = false;
                if (resultEmbeddingModels == 0) {
                    // NOTE(review): 0.413 is passed through as-is; its meaning is defined by
                    // AgentGenerativeModelClient#isRelatedTiming - confirm before changing.
                    AgentGenerativeModelClient.TimedRelatedness resultGenerativeModelTimed =
                            AgentGenerativeModelClient.getInstance()
                                    .isRelatedTiming(resembles, printedInvocation, 0.413);
                    resultGenerativeModel = resultGenerativeModelTimed.isRelated();
                    calledGenerativeModel = true;
                    Duration timing = resultGenerativeModelTimed.getDuration();
                    requireNonNull(getCursor().<AtomicInteger>getNearestMessage("countGenerative")).incrementAndGet();
                    requireNonNull(getCursor().<GenerativeModelPerformance.Histogram>getNearestMessage("histogramGenerative")).add(timing);
                    recordMax(requireNonNull(getCursor().<AtomicLong>getNearestMessage("maxGenerative")), timing);
                }

                // Populate data table for debugging model's accuracy
                JavaSourceFile javaSourceFile = requireNonNull(
                        getCursor().firstEnclosing(JavaSourceFile.class),
                        "method invocation is not enclosed by a JavaSourceFile");
                String source = javaSourceFile.getSourcePath().toString();
                codeSearchTable.insertRow(ctx, new CodeSearch.Row(
                        source,
                        printedInvocation,
                        resembles,
                        resultEmbeddingModels,
                        calledGenerativeModel,
                        resultGenerativeModel
                ));

                // The generative model only runs when the embedding verdict is 0, so this single
                // condition covers both "generative said yes" and "embeddings said yes".
                boolean found = resultGenerativeModel || resultEmbeddingModels == 1;
                if (found) {
                    suggestedMethodPatternsTable.insertRow(ctx, new SuggestedMethodPatterns.Row(
                            printedInvocation,
                            methodPattern,
                            resembles
                    ));
                    return SearchResult.found(method);
                }
                return super.visitMethodInvocation(method, ctx);
            }
        });
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy