All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.hl7.fhir.validation.ai.AITests Maven / Gradle / Ivy

The newest version!
package org.hl7.fhir.validation.ai;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

import org.hl7.fhir.utilities.Utilities;
import org.hl7.fhir.utilities.http.ManagedWebAccess;
import org.hl7.fhir.utilities.json.model.JsonObject;
import org.hl7.fhir.utilities.json.parser.JsonParser;
import org.hl7.fhir.utilities.xhtml.HierarchicalTableGenerator;

public class AITests {

  /**
   * Confusion-matrix counters for one AI engine, treating "valid" as the positive class, plus the
   * derived statistics printed as one row of the summary table.
   */
  public static class StatsRecord {

    /** Number of cases recorded. */
    public int total;
    /** Cases where the engine's verdict matched the expected verdict. */
    public int correct;
    /** True negatives: expected invalid and the engine said invalid. */
    public int correctNeg;
    /** True positives: expected valid and the engine said valid. */
    public int correctPos;
    /** Cases where the engine's verdict did not match the expected verdict. */
    public int wrong;
    /** Expected valid but the engine said invalid. */
    public int falseNegative;
    /** Cases expected to be invalid (true negatives + false positives). */
    public int actualNegatives;
    /** Expected invalid but the engine said valid. */
    public int falsePositive;

    /**
     * Renders the statistics as one table row; cell widths match the header printed by
     * {@code execute}. All values are truncated integer percentages.
     *
     * @return the row text, starting with "| " and ending with "|"
     */
    public String summary() {
      StringBuilder b = new StringBuilder();
      appendCell(b, percent(correct, total), 6);
      appendCell(b, percent(correctNeg, total), 6);
      appendCell(b, percent(correctPos, total), 5);
      appendCell(b, falsePosRate(), 7);
      appendCell(b, falseNegRate(), 7);
      appendCell(b, sensitivity(), 12);
      appendCell(b, specificity(), 12);
      appendCell(b, ppv(), 5);
      b.append("|");
      return b.toString();
    }

    /** Appends a "| " separator followed by {@code value} right-padded to {@code width}. */
    private static void appendCell(StringBuilder b, int value, int width) {
      b.append("| ");
      b.append(Utilities.padRight(value, ' ', width));
    }

    /** Truncated integer percentage of numerator over denominator; 0 when the denominator is 0. */
    private static int percent(int numerator, int denominator) {
      return denominator == 0 ? 0 : (int) ((numerator * 100L) / denominator);
    }

    /** Positive predictive value: TP / (TP + FP). */
    private int ppv() {
      return percent(correctPos, correctPos + falsePositive);
    }

    /** Specificity (true negative rate): TN / (TN + FP). */
    private int specificity() {
      return percent(correctNeg, correctNeg + falsePositive);
    }

    /** Sensitivity (true positive rate): TP / (TP + FN). */
    private int sensitivity() {
      return percent(correctPos, correctPos + falseNegative);
    }

    /** False negative rate: FN / (FN + TP). */
    private int falseNegRate() {
      return percent(falseNegative, falseNegative + correctPos);
    }

    /** False positive rate: FP / (FP + TN). */
    private int falsePosRate() {
      return percent(falsePositive, falsePositive + correctNeg);
    }

    /**
     * Records one case.
     *
     * @param expected true when the case is expected to be valid
     * @param passed true when the engine judged the case valid
     */
    public void update(boolean expected, boolean passed) {
      total++;
      if (!expected) {
        actualNegatives++;
      }
      if (passed == expected) {
        correct++;
        if (expected) {
          correctPos++;
        } else {
          correctNeg++;
        }
      } else {
        wrong++;
        if (expected) {
          falseNegative++;
        } else {
          falsePositive++;
        }
      }
    }
  }

  /**
   * Command-line entry point.
   *
   * @param args {@code <test-file> [<config-file>] [true|false]}. When absent, the config defaults
   *     to the {@code ai-prompts.json} classpath resource and "use servers" defaults to true.
   */
  public static void main(String[] args) throws IOException {
    if (args.length == 0) {
      System.err.println("Usage: AITests <test-file> [<config-file>] [true|false]");
      return;
    }
    String config = args.length > 1 ? args[1] : null;
    // only read args[2] when it exists; otherwise querying the servers is the default
    boolean useServers = args.length > 2 ? "true".equals(args[2]) : true;
    new AITests().execute(args[0], config, useServers);
  }

  /**
   * Runs the AI code/text validation comparison.
   *
   * Loads the test cases from {@code testFilename}. When {@code useServers} is true, it sends every
   * case to Ollama, ChatGPT and Claude, stores each engine's verdict in the case JSON, and prints a
   * per-case comparison. It then tallies the verdicts stored under the "claude", "chatgpt" and
   * "ollama" keys of each case and prints summary statistics and per-engine tables.
   *
   * @param testFilename JSON file whose "cases" array holds the test cases
   * @param config JSON configuration file with "ollama", "chatGPT" and "claude" sections, or null
   *     to use the ai-prompts.json classpath resource
   * @param useServers whether to query the AI services; when false, every case must already carry
   *     the engines' results
   * @throws IOException if the configuration or test file cannot be read, or an engine call fails
   */
  public void execute(String testFilename, String config, boolean useServers) throws IOException {
    ManagedWebAccess.loadFromFHIRSettings();

    JsonObject jcfg = loadConfig(config);
    JsonObject tests = JsonParser.parseObject(new File(testFilename));
    List<CodeAndTextValidationRequest> requests = buildRequests(tests);

    if (useServers) {
      runServers(jcfg, requests);
    }

    StatsRecord claude = new StatsRecord();
    StatsRecord chatGPT = new StatsRecord();
    StatsRecord ollama = new StatsRecord();

    for (CodeAndTextValidationRequest req : requests) {
      System.out.print(".");
      JsonObject test = (JsonObject) req.getData();
      test.remove("disagrement");
      test.remove("unanimous");
      boolean expected = isExpectedValid(test);
      boolean bClaude = recordedVerdict(test, "claude");
      boolean bChatGPT = recordedVerdict(test, "chatgpt");
      boolean bOllama = recordedVerdict(test, "ollama");
      claude.update(expected, bClaude);
      chatGPT.update(expected, bChatGPT);
      ollama.update(expected, bOllama);
      // Optional: flag disagreements in the test file (re-enable together with the compose below)
//      boolean agreement = (bClaude == expected) && (bChatGPT == expected) && (bOllama == expected);
//      boolean unanimous = (bClaude == bChatGPT) && (bClaude == bOllama);
//      if (!agreement) {
//        test.add("disagrement", true);
//        if (unanimous) {
//          test.add("unanimous", true);        
//        }
//      }
    }
//    JsonParser.compose(tests, new File(testFilename), true);

    System.out.println("");
    System.out.println("        | Number tests correct | %False results  | Classic Diagnostic Statistics    |"); 
    System.out.println("        | #All  | #Neg  | #Pos | %F.Pos | %F.Neg | sensitivity | specificity | PPV  |"); 
    System.out.println("-------------------------------------------------------------------------------------");
    System.out.println("Claude  "+claude.summary());
    System.out.println("ChatGPT "+chatGPT.summary());
    System.out.println("Ollama  "+ollama.summary());

    doTable("Claude", claude);
    doTable("ChatGPT", chatGPT);
    doTable("Ollama", ollama);
  }

  /**
   * Parses the engine configuration from {@code config}, or from the {@code ai-prompts.json}
   * classpath resource when {@code config} is null. The stream is always closed.
   */
  private static JsonObject loadConfig(String config) throws IOException {
    if (config == null) {
      ClassLoader classLoader = HierarchicalTableGenerator.class.getClassLoader();
      try (InputStream cfg = classLoader.getResourceAsStream("ai-prompts.json")) {
        // getResourceAsStream returns null rather than throwing when the resource is missing
        if (cfg == null) {
          throw new IOException("Unable to find classpath resource ai-prompts.json");
        }
        return JsonParser.parseObject(cfg);
      }
    }
    try (InputStream cfg = new FileInputStream(config)) {
      return JsonParser.parseObject(cfg);
    }
  }

  /**
   * Builds one validation request per entry of the "cases" array, attaching the case JSON as the
   * request data, and reports how many cases are expected to be valid.
   */
  private static List<CodeAndTextValidationRequest> buildRequests(JsonObject tests) {
    List<CodeAndTextValidationRequest> requests = new ArrayList<>();
    int expectedValid = 0;
    for (JsonObject test : tests.getJsonArray("cases").asJsonObjects()) {
      requests.add(new CodeAndTextValidationRequest(null, test.asString("path"), test.asString("lang"), test.asString("system"), test.asString("code"), 
          test.asString("display"),test.asString("text")).setData(test));   
      if (isExpectedValid(test)) {
        expectedValid++;
      }
    }
    System.out.println("Found "+requests.size()+" tests, "+expectedValid+" should be valid");
    return requests;
  }

  /**
   * Sends all requests to Ollama, ChatGPT and Claude, timing each engine. It then prints a per-case
   * comparison, storing each engine's verdict in the case JSON via {@code check}.
   */
  private static void runServers(JsonObject jcfg, List<CodeAndTextValidationRequest> requests) throws IOException {
    System.out.print("Ollama");
    long t = System.currentTimeMillis();
    List<CodeAndTextValidationResult> resOllama = new Ollama(jcfg.forceObject("ollama"), null).validateCodings(requests);
    System.out.println(": "+Utilities.describeDuration(System.currentTimeMillis() - t));

    System.out.print("ChatGPT");
    t = System.currentTimeMillis();
    List<CodeAndTextValidationResult> resChatGPT = new ChatGPTAPI(jcfg.forceObject("chatGPT")).validateCodings(requests);
    System.out.println(": "+Utilities.describeDuration(System.currentTimeMillis() - t));

    System.out.print("Claude");
    t = System.currentTimeMillis();
    List<CodeAndTextValidationResult> resClaude = new ClaudeAPI(jcfg.forceObject("claude")).validateCodings(requests);
    System.out.println(": "+Utilities.describeDuration(System.currentTimeMillis() - t));

    System.out.println("");

    for (int i = 0; i < requests.size(); i++) {
      CodeAndTextValidationRequest req = requests.get(i);
      JsonObject test = (JsonObject) req.getData();
      System.out.println("Case "+req.getSystem()+"#"+req.getCode()+" ('"+req.getDisplay()+"') :: '"+req.getText()+"'");
      CodeAndTextValidationResult res = resClaude.get(i);
      System.out.println("  Claude : "+check(test, res, "claude")+"; "+res.summary());
      res = resChatGPT.get(i);
      System.out.println("  ChatGPT: "+check(test, res, "chatgpt")+"; "+res.summary());
      res = resOllama.get(i);
      System.out.println("  Ollama : "+check(test, res, "ollama")+"; "+res.summary());
      System.out.println("");    
    }
  }

  /** True when the case's "goal" starts with "valid", i.e. the case is expected to pass. */
  private static boolean isExpectedValid(JsonObject test) {
    return test.asString("goal").startsWith("valid");
  }

  /**
   * Returns the "valid" flag stored in the case under the {@code engine} key.
   *
   * @throws IllegalStateException if the case holds no result object for that engine; this
   *     presumably means the servers were not queried and the file lacks stored results
   */
  private static boolean recordedVerdict(JsonObject test, String engine) {
    JsonObject result = test.getJsonObject(engine);
    if (result == null) {
      throw new IllegalStateException("Test case has no '" + engine + "' result; run with servers enabled");
    }
    return result.asBoolean("valid");
  }

  /**
   * Prints a 2x2 table for one engine. Rows show whether the engine was right; columns show the
   * verdict the engine gave. "Wrong | Valid" is therefore the false-positive count and
   * "Wrong | Invalid" is the false-negative count.
   */
  private static void doTable(String name, StatsRecord rec) {
    System.out.println("");
    System.out.println("");
    System.out.println(Utilities.padRight(name, ' ', 7)+" | Valid | Invalid |"); 
    System.out.println("--------------------------|");
    System.out.println("Correct | "+Utilities.padRight(rec.correctPos, ' ', 5)+" | "+Utilities.padRight(rec.correctNeg, ' ', 7)+" |");
    System.out.println("Wrong   | "+Utilities.padRight(rec.falsePositive, ' ', 5)+" | "+Utilities.padRight(rec.falseNegative, ' ', 7)+" |");
  }

  /**
   * Stores an engine's verdict, explanation and confidence in the case JSON under {@code code}.
   *
   * @return "T " when the verdict matches the expected goal; otherwise "F:" followed by the
   *     engine's verdict ("T" = said valid, "F" = said invalid)
   */
  private static String check(JsonObject test, CodeAndTextValidationResult res, String code) {
    boolean passed = res.isValid();
    boolean expected = isExpectedValid(test);
    JsonObject o = test.forceObject(code);
    o.set("valid", passed);
    o.set("explanation", res.getExplanation());
    o.set("confidence", res.getConfidence());
    return passed == expected ? "T " : "F:" + (passed ? "T" : "F");
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy