All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dev.langchain4j.store.embedding.vespa.VespaQueryClient Maven / Gradle / Ivy

// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package dev.langchain4j.store.embedding.vespa;

import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.ArrayList;
import java.util.List;
import javax.net.ssl.*;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x9.X9ObjectIdentifiers;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

/**
 * Factory for a {@link VespaQueryApi} client that authenticates with a PEM encoded client
 * certificate and private key (mutual TLS).
 *
 * <p>This class is a workaround that becomes redundant as soon as vespa-client is implemented.
 * The PEM handling is copied from vespa-feed-client: a BouncyCastle integration for creating an
 * {@link SSLContext} instance from PEM encoded material.
 */
class VespaQueryClient {

  /** BouncyCastle provider used for certificate conversion and key factories. */
  static final BouncyCastleProvider bcProvider = new BouncyCastleProvider();

  /**
   * Builds a Retrofit-backed {@link VespaQueryApi} that presents the given client certificate.
   *
   * @param baseUrl base URL of the Vespa endpoint; a trailing slash is added when missing
   * @param certificate path to a PEM file containing one or more X.509 certificates
   * @param privateKey path to a PEM file containing the matching private key (EC or RSA)
   * @return the API proxy created by Retrofit
   * @throws RuntimeException wrapping any {@link IOException} or {@link GeneralSecurityException}
   *     raised while reading the PEM material or setting up TLS
   */
  public static VespaQueryApi createInstance(String baseUrl, Path certificate, Path privateKey) {
    try {
      // In-memory key store holding only the client key and its certificate chain.
      KeyStore keystore = KeyStore.getInstance("PKCS12");
      keystore.load(null);
      keystore.setKeyEntry("cert", privateKey(privateKey), new char[0], certificates(certificate));

      // One trust manager is shared by the SSLContext and OkHttp, so both agree on how server
      // certificates are validated.
      X509TrustManager trustManager = defaultTrustManager();

      // Protocol version must be equal to TlsContext.SSL_CONTEXT_VERSION or higher
      SSLContext sslContext = SSLContext.getInstance("TLSv1.3");
      sslContext.init(
        createKeyManagers(keystore),
        new TrustManager[] { trustManager },
        /*Default secure random algorithm*/null
      );

      OkHttpClient client = new OkHttpClient.Builder()
        .sslSocketFactory(sslContext.getSocketFactory(), trustManager)
        .addInterceptor(chain -> {
          Request request = chain.request();
          return chain.proceed(request.newBuilder().url(toVespaQueryUrl(request.url())).build());
        })
        .build();

      Retrofit retrofit = new Retrofit.Builder()
        .baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
        .client(client)
        .addConverterFactory(GsonConverterFactory.create(new GsonBuilder().create()))
        .build();

      return retrofit.create(VespaQueryApi.class);
    } catch (IOException | GeneralSecurityException e) {
      throw new RuntimeException("Failed to create Vespa query client for " + baseUrl, e);
    }
  }

  /**
   * Trick to format the query URL exactly how Vespa expects it (search/?query), see
   * https://docs.vespa.ai/en/reference/query-language-reference.html
   *
   * <p>Removes the path segment at index 1, appends an empty path segment (producing a trailing
   * slash) and installs the removed segment's encoded text as the encoded query string.
   *
   * <p>NOTE(review): requires at least two path segments; presumably the API path is
   * {@code search/{query}} relative to a base URL without a path prefix - verify against
   * {@code VespaQueryApi}.
   */
  private static HttpUrl toVespaQueryUrl(HttpUrl url) {
    return url
      .newBuilder()
      .removePathSegment(1)
      .addPathSegment("")
      .encodedQuery(url.encodedPathSegments().get(1))
      .build();
  }

  /**
   * Returns the first {@link X509TrustManager} produced by the default-algorithm
   * {@link TrustManagerFactory} initialised with a {@code null} key store, i.e. the platform's
   * default trust material.
   *
   * @throws GeneralSecurityException if the factory yields no {@link X509TrustManager}
   */
  private static X509TrustManager defaultTrustManager() throws GeneralSecurityException {
    TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
      TrustManagerFactory.getDefaultAlgorithm()
    );
    trustManagerFactory.init((KeyStore) null);
    for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) {
      if (trustManager instanceof X509TrustManager) {
        return (X509TrustManager) trustManager;
      }
    }
    throw new GeneralSecurityException("No X509TrustManager available from the default TrustManagerFactory");
  }

  /** Creates key managers for the given key store, whose entries use an empty password. */
  private static KeyManager[] createKeyManagers(KeyStore keystore) throws GeneralSecurityException {
    KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
    kmf.init(keystore, new char[0]);
    return kmf.getKeyManagers();
  }

  /**
   * Reads every PEM object in {@code file} and converts each one to an X.509 certificate.
   *
   * @return the certificates in file order
   * @throws IOException if the file cannot be read, contains no PEM objects, or contains a PEM
   *     object that is not a certificate
   */
  private static Certificate[] certificates(Path file) throws IOException, GeneralSecurityException {
    try (PEMParser parser = new PEMParser(Files.newBufferedReader(file))) {
      List<Certificate> result = new ArrayList<>();
      Object pemObject;
      while ((pemObject = parser.readObject()) != null) {
        result.add(toX509Certificate(pemObject));
      }
      if (result.isEmpty()) {
        throw new IOException("File contains no PEM encoded certificates: " + file);
      }
      return result.toArray(new Certificate[0]);
    }
  }

  /**
   * Returns the first private key found in {@code file}. Both bare {@link PrivateKeyInfo} objects
   * and {@link PEMKeyPair} objects are accepted; other PEM objects are skipped.
   *
   * @throws IOException if the file cannot be read, holds no private key, or the key algorithm is
   *     neither EC nor RSA
   */
  private static PrivateKey privateKey(Path file) throws IOException, GeneralSecurityException {
    try (PEMParser parser = new PEMParser(Files.newBufferedReader(file))) {
      Object pemObject;
      while ((pemObject = parser.readObject()) != null) {
        PrivateKeyInfo keyInfo = null;
        if (pemObject instanceof PrivateKeyInfo) {
          keyInfo = (PrivateKeyInfo) pemObject;
        } else if (pemObject instanceof PEMKeyPair) {
          keyInfo = ((PEMKeyPair) pemObject).getPrivateKeyInfo();
        }
        if (keyInfo != null) {
          PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyInfo.getEncoded());
          return createKeyFactory(keyInfo).generatePrivate(keySpec);
        }
      }
      throw new IOException("Could not find private key in PEM file: " + file);
    }
  }

  /**
   * Converts a parsed PEM object to an {@link X509Certificate}.
   *
   * @throws IOException if the object is neither an {@link X509Certificate} nor an
   *     {@link X509CertificateHolder}
   */
  private static X509Certificate toX509Certificate(Object pemObject) throws IOException, GeneralSecurityException {
    if (pemObject instanceof X509Certificate) {
      return (X509Certificate) pemObject;
    }
    if (pemObject instanceof X509CertificateHolder) {
      return new JcaX509CertificateConverter()
        .setProvider(bcProvider)
        .getCertificate((X509CertificateHolder) pemObject);
    }
    throw new IOException("Invalid type of PEM object: " + pemObject);
  }

  /**
   * Selects a BouncyCastle {@link KeyFactory} matching the algorithm OID of the given key.
   *
   * @throws IOException if the algorithm is neither EC nor RSA
   */
  private static KeyFactory createKeyFactory(PrivateKeyInfo info) throws IOException, GeneralSecurityException {
    ASN1ObjectIdentifier algorithm = info.getPrivateKeyAlgorithm().getAlgorithm();
    if (X9ObjectIdentifiers.id_ecPublicKey.equals(algorithm)) {
      return KeyFactory.getInstance("EC", bcProvider);
    } else if (PKCSObjectIdentifiers.rsaEncryption.equals(algorithm)) {
      return KeyFactory.getInstance("RSA", bcProvider);
    } else {
      throw new IOException("Unknown key algorithm: " + algorithm);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy