All downloads are FREE. Search and download functionality uses the official Maven repository.

io.moderne.ai.research.FindCodeThatResembles Maven / Gradle / Ivy

There is a newer version: 0.24.1
Show newest version
/*
 * Copyright 2021 the original author or authors.
 * 

 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *

 * https://www.apache.org/licenses/LICENSE-2.0
 *

* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.moderne.ai.research;

import io.moderne.ai.AgentGenerativeModelClient;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.RelatedModelClient;
import io.moderne.ai.table.CodeSearch;
import io.moderne.ai.table.EmbeddingPerformance;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.experimental.NonFinal;
import org.openrewrite.*;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.marker.SearchResult;

import java.time.Duration;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;

import static java.util.Objects.requireNonNull;

/**
 * Finds method invocations that resemble a natural-language description or a code sample.
 * <p>
 * The search runs in two phases:
 * <ol>
 *   <li>Scanning: every method used by a compilation unit is rendered as a simplified signature,
 *   scored against {@code resembles} with an embedding distance, and queued by ascending distance.
 *   The {@code k} distinct method patterns with the lowest distance become method matchers.</li>
 *   <li>Visiting: only invocations matched by one of those matchers are sent to
 *   {@link RelatedModelClient}. A verdict of {@code 1} means related; a verdict of {@code 0}
 *   defers the decision to {@link AgentGenerativeModelClient}. Related invocations are marked
 *   with {@link SearchResult} and recorded in the {@link CodeSearch} data table.</li>
 * </ol>
 */
@Value
@EqualsAndHashCode(callSuper = false)
public class FindCodeThatResembles extends ScanningRecipe<FindCodeThatResembles.Accumulator> {

    /** Type-argument span (greedy, e.g. {@code <java.lang.String>}) removed from type names. */
    private static final Pattern TYPE_PARAMETERS = Pattern.compile("<.*>");

    /**
     * Third argument passed to {@link AgentGenerativeModelClient#isRelated}.
     * NOTE(review): presumably a decision threshold for the generative model — confirm against the client.
     */
    private static final double GENERATIVE_MODEL_THRESHOLD = 0.5932;

    @Option(displayName = "Resembles",
            description = "The text, either a natural language description or a code sample, " +
                          "that you are looking for.",
            example = "HTTP request with Content-Type application/json")
    String resembles;

    @Option(displayName = "top k methods",
            description = "Since AI based matching has a higher latency than rules based matching, " +
                          "we do a first pass to find the top k methods using embeddings.\n" +
                          "To narrow the scope, you can specify the top k methods as method filters.",
            example = "1000")
    int k;

    transient CodeSearch codeSearchTable = new CodeSearch(this);

    transient EmbeddingPerformance performance = new EmbeddingPerformance(this);

    @Override
    public String getDisplayName() {
        return "Find method invocations that resemble a pattern";
    }

    @Override
    public String getDescription() {
        return "This recipe uses two phase AI approach to find a method invocation" +
               " that resembles a search string.";
    }

    /** A scanned method: simplified signature, matcher pattern, and embedding distance to {@code resembles}. */
    @Value
    private static class MethodSignatureWithDistance {
        String methodSignature;
        String methodPattern;
        double distance;
    }

    /**
     * Scanning state shared across all source files: scored method patterns ordered by ascending
     * embedding distance, plus the memoized top-k matchers derived from them.
     */
    @Value
    @RequiredArgsConstructor
    public static class Accumulator {
        int k;

        PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue =
                new PriorityQueue<>(Comparator.comparingDouble(MethodSignatureWithDistance::getDistance));

        /** Patterns already scored, so {@link #add} rejects duplicates in O(1) instead of scanning the queue. */
        Set<String> seenMethodPatterns = new HashSet<>();

        EmbeddingModelClient embeddingModelClient = EmbeddingModelClient.getInstance();

        /** Memoized result of {@link #getMethodMatchersTopK()}; {@code null} until first requested. */
        @NonFinal
        @Nullable
        List<MethodMatcher> topMethodPatterns;

        /**
         * Scores {@code methodSignature} against {@code resembles} and enqueues it, unless a method
         * with the same {@code methodPattern} has already been added. The duplicate check runs first
         * so no embedding call is made for a pattern that was already scored.
         *
         * @param methodSignature simplified signature used for the embedding comparison
         * @param methodPattern   {@link MethodMatcher} pattern identifying the method
         * @param resembles       the text being searched for
         */
        public void add(String methodSignature, String methodPattern, String resembles) {
            if (seenMethodPatterns.contains(methodPattern)) {
                return;
            }
            double distance = embeddingModelClient.getDistance(resembles, methodSignature);
            methodSignaturesQueue.add(new MethodSignatureWithDistance(methodSignature, methodPattern, distance));
            seenMethodPatterns.add(methodPattern);
        }

        /**
         * Returns matchers for at most {@code k} queued patterns with the lowest embedding distance.
         * The first call drains those entries from the queue; the list is memoized, so later calls
         * return the same instance.
         *
         * @return the top-k method matchers (empty when {@code k <= 0} or nothing was scanned)
         */
        public List<MethodMatcher> getMethodMatchersTopK() {
            if (topMethodPatterns != null) {
                return topMethodPatterns;
            }
            // Clamp the capacity: k may be non-positive or larger than the number of scanned methods.
            List<MethodMatcher> top = new ArrayList<>(Math.max(0, Math.min(k, methodSignaturesQueue.size())));
            for (int i = 0; i < k && !methodSignaturesQueue.isEmpty(); i++) {
                top.add(new MethodMatcher(methodSignaturesQueue.poll().getMethodPattern(), true));
            }
            topMethodPatterns = top;
            return top;
        }
    }

    @Override
    public Accumulator getInitialValue(ExecutionContext ctx) {
        return new Accumulator(k);
    }

    /**
     * Scans each compilation unit's used methods and adds a simplified signature and matcher
     * pattern for each one to the accumulator.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
        return new JavaIsoVisitor<ExecutionContext>() {
            /**
             * Reduces a type name to its simple name without type arguments, e.g.
             * {@code java.util.List<java.lang.String>} becomes {@code List}.
             */
            private String extractTypeName(String fullyQualifiedTypeName) {
                // String.replace would match "<.*>" literally, so a regex replacement is required.
                // The last '.' must be located in the stripped name, not in the original one.
                String withoutTypeParameters = TYPE_PARAMETERS.matcher(fullyQualifiedTypeName).replaceAll("");
                return withoutTypeParameters.substring(withoutTypeParameters.lastIndexOf('.') + 1);
            }

            @SuppressWarnings("OptionalOfNullableMisuse")
            @Override
            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                cu.getTypesInUse().getUsedMethods().forEach(type -> {
                    String returnType = extractTypeName(Optional.ofNullable(type.getReturnType())
                            .map(Object::toString).orElse(""));

                    // Render parameters as "SimpleType name" pairs.
                    String[] parameters = new String[type.getParameterTypes().size()];
                    for (int i = 0; i < parameters.length; i++) {
                        parameters[i] = extractTypeName(type.getParameterTypes().get(i).toString()) +
                                        " " + type.getParameterNames().get(i);
                    }

                    String methodSignature = returnType + " " + type.getName() +
                                             "(" + String.join(", ", parameters) + ")";
                    String methodPattern = Optional.ofNullable(type.getDeclaringType()).map(Object::toString)
                                                   .orElse("") + " " + type.getName() + "(..)";
                    acc.add(methodSignature, methodPattern, resembles);
                });
                return super.visitCompilationUnit(cu, ctx);
            }
        };
    }

    /**
     * Visits only files that use one of the top-k methods. Marks matching invocations judged
     * related to {@code resembles}, and records embedding timing statistics per file.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
        List<MethodMatcher> methodMatchers = acc.getMethodMatchersTopK();
        List<TreeVisitor<?, ExecutionContext>> preconditions = new ArrayList<>(methodMatchers.size());
        for (MethodMatcher m : methodMatchers) {
            preconditions.add(new UsesMethod<>(m));
        }
        //noinspection unchecked
        return Preconditions.check(Preconditions.or(preconditions.toArray(new TreeVisitor[0])),
                new JavaIsoVisitor<ExecutionContext>() {
                    @Override
                    public boolean isAcceptable(SourceFile sourceFile, ExecutionContext ctx) {
                        return sourceFile instanceof J.CompilationUnit;
                    }

                    @Override
                    public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                        // Per-file embedding statistics, updated by visitMethodInvocation via getNearestMessage.
                        getCursor().putMessage("count", new AtomicInteger());
                        getCursor().putMessage("max", new AtomicLong());
                        getCursor().putMessage("histogram", new EmbeddingPerformance.Histogram());
                        try {
                            return super.visitCompilationUnit(cu, ctx);
                        } finally {
                            AtomicInteger count = getCursor().getMessage("count", new AtomicInteger());
                            if (count.get() > 0) {
                                AtomicLong max = requireNonNull(getCursor().<AtomicLong>getMessage("max"));
                                EmbeddingPerformance.Histogram histogram = requireNonNull(
                                        getCursor().<EmbeddingPerformance.Histogram>getMessage("histogram"));
                                performance.insertRow(ctx, new EmbeddingPerformance.Row(
                                        cu.getSourcePath().toString(),
                                        count.get(),
                                        histogram.getBuckets(),
                                        Duration.ofNanos(max.get())));
                            }
                        }
                    }

                    @Override
                    public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                        if (methodMatchers.stream().noneMatch(matcher -> matcher.matches(method))) {
                            return super.visitMethodInvocation(method, ctx);
                        }

                        // Printed once and reused for every model call and for the data table row.
                        String printed = method.printTrimmed(getCursor());

                        RelatedModelClient.Relatedness related = RelatedModelClient.getInstance()
                                .getRelatedness(resembles, printed);

                        AtomicInteger count = requireNonNull(getCursor().<AtomicInteger>getNearestMessage("count"));
                        EmbeddingPerformance.Histogram histogram = requireNonNull(
                                getCursor().<EmbeddingPerformance.Histogram>getNearestMessage("histogram"));
                        AtomicLong max = requireNonNull(getCursor().<AtomicLong>getNearestMessage("max"));
                        for (Duration timing : related.getEmbeddingTimings()) {
                            count.incrementAndGet();
                            histogram.add(timing);
                            max.accumulateAndGet(timing.toNanos(), Math::max);
                        }

                        // Verdict 1 = related; 0 = defer to the generative model; anything else = not related.
                        int resultEmbeddingModels = related.isRelated();
                        boolean calledGenerativeModel = resultEmbeddingModels == 0;
                        boolean result = calledGenerativeModel
                                ? AgentGenerativeModelClient.getInstance()
                                        .isRelated(resembles, printed, GENERATIVE_MODEL_THRESHOLD)
                                : resultEmbeddingModels == 1;

                        if (!result) {
                            return super.visitMethodInvocation(method, ctx);
                        }

                        // Populate data table for debugging model's accuracy
                        String source = requireNonNull(getCursor().firstEnclosing(JavaSourceFile.class))
                                .getSourcePath().toString();
                        codeSearchTable.insertRow(ctx, new CodeSearch.Row(
                                source,
                                printed,
                                resembles,
                                resultEmbeddingModels,
                                // Only positive results are recorded, so a generative verdict here is always 1.
                                calledGenerativeModel ? 1 : 0
                        ));
                        return SearchResult.found(method);
                    }
                });
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy