All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.security.PKIAuthenticationPlugin Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.security;

import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.security.PublicKey;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.http.HttpEntity;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.HttpResponse;
import org.apache.http.auth.BasicUserPrincipal;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.apache.solr.client.solrj.impl.Http2SolrClient;
import org.apache.solr.client.solrj.impl.HttpClientUtil;
import org.apache.solr.client.solrj.impl.HttpListenerFactory;
import org.apache.solr.client.solrj.impl.SolrHttpClientBuilder;
import org.apache.solr.common.util.Base64;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.StrUtils;
import org.apache.solr.common.util.SuppressForbidden;
import org.apache.solr.common.util.Utils;
import org.apache.solr.core.CoreContainer;
import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.util.CryptoKeys;
import org.eclipse.jetty.client.api.Request;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static java.nio.charset.StandardCharsets.UTF_8;

public class PKIAuthenticationPlugin extends AuthenticationPlugin implements HttpClientBuilderPlugin {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
  private final Map keyCache = new ConcurrentHashMap<>();
  private final PublicKeyHandler publicKeyHandler;
  private final CoreContainer cores;
  private final int MAX_VALIDITY = Integer.parseInt(System.getProperty("pkiauth.ttl", "15000"));
  private final String myNodeName;
  private final HttpHeaderClientInterceptor interceptor = new HttpHeaderClientInterceptor();
  private boolean interceptorRegistered = false;

  public boolean isInterceptorRegistered(){
    return interceptorRegistered;
  }

  public PKIAuthenticationPlugin(CoreContainer cores, String nodeName, PublicKeyHandler publicKeyHandler) {
    this.publicKeyHandler = publicKeyHandler;
    this.cores = cores;
    myNodeName = nodeName;
  }

  @Override
  public void init(Map pluginConfig) {
  }


  @SuppressForbidden(reason = "Needs currentTimeMillis to compare against time in header")
  @Override
  public boolean doAuthenticate(ServletRequest request, ServletResponse response, FilterChain filterChain) throws Exception {

    String requestURI = ((HttpServletRequest) request).getRequestURI();
    if (requestURI.endsWith(PublicKeyHandler.PATH)) {
      numPassThrough.inc();
      filterChain.doFilter(request, response);
      return true;
    }
    long receivedTime = System.currentTimeMillis();
    String header = ((HttpServletRequest) request).getHeader(HEADER);
    if (header == null) {
      //this must not happen
      log.error("No SolrAuth header present");
      numMissingCredentials.inc();
      filterChain.doFilter(request, response);
      return true;
    }

    List authInfo = StrUtils.splitWS(header, false);
    if (authInfo.size() < 2) {
      log.error("Invalid SolrAuth Header {}", header);
      numErrors.mark();
      filterChain.doFilter(request, response);
      return true;
    }

    String nodeName = authInfo.get(0);
    String cipher = authInfo.get(1);

    PKIHeaderData decipher = decipherHeader(nodeName, cipher);
    if (decipher == null) {
      log.error("Could not decipher a header {} . No principal set", header);
      numMissingCredentials.inc();
      filterChain.doFilter(request, response);
      return true;
    }
    if ((receivedTime - decipher.timestamp) > MAX_VALIDITY) {
      log.error("Invalid key request timestamp: {} , received timestamp: {} , TTL: {}", decipher.timestamp, receivedTime, MAX_VALIDITY);
      numErrors.mark();
      filterChain.doFilter(request, response);
      return true;
    }

    final Principal principal = "$".equals(decipher.userName) ?
        SU :
        new BasicUserPrincipal(decipher.userName);

    numAuthenticated.inc();
    filterChain.doFilter(getWrapper((HttpServletRequest) request, principal), response);
    return true;
  }

  private static HttpServletRequestWrapper getWrapper(final HttpServletRequest request, final Principal principal) {
    return new HttpServletRequestWrapper(request) {
      @Override
      public Principal getUserPrincipal() {
        return principal;
      }
    };
  }

  public static class PKIHeaderData {
    String userName;
    long timestamp;
  }

  private PKIHeaderData decipherHeader(String nodeName, String cipherBase64) {
    PublicKey key = keyCache.get(nodeName);
    if (key == null) {
      log.debug("No key available for node : {} fetching now ", nodeName);
      key = getRemotePublicKey(nodeName);
      log.debug("public key obtained {} ", key);
    }

    PKIHeaderData header = parseCipher(cipherBase64, key);
    if (header == null) {
      log.warn("Failed to decrypt header, trying after refreshing the key ");
      key = getRemotePublicKey(nodeName);
      return parseCipher(cipherBase64, key);
    } else {
      return header;
    }
  }

  private static  PKIHeaderData parseCipher(String cipher, PublicKey key) {
    byte[] bytes;
    try {
      bytes = CryptoKeys.decryptRSA(Base64.base64ToByteArray(cipher), key);
    } catch (Exception e) {
      log.error("Decryption failed , key must be wrong", e);
      return null;
    }
    String s = new String(bytes, UTF_8).trim();
    String[] ss = s.split(" ");
    if (ss.length < 2) {
      log.warn("Invalid cipher {} deciphered data {}", cipher, s);
      return null;
    }
    PKIHeaderData headerData = new PKIHeaderData();
    try {
      headerData.timestamp = Long.parseLong(ss[1]);
      headerData.userName = ss[0];
      log.debug("Successfully decrypted header {} {}", headerData.userName, headerData.timestamp);
      return headerData;
    } catch (NumberFormatException e) {
      log.warn("Invalid cipher {}", cipher);
      return null;
    }
  }

  PublicKey getRemotePublicKey(String nodename) {
    if (!cores.getZkController().getZkStateReader().getClusterState().getLiveNodes().contains(nodename)) return null;
    String url = cores.getZkController().getZkStateReader().getBaseUrlForNodeName(nodename);
    HttpEntity entity = null;
    try {
      String uri = url + PublicKeyHandler.PATH + "?wt=json&omitHeader=true";
      log.debug("Fetching fresh public key from : {}",uri);
      HttpResponse rsp = cores.getUpdateShardHandler().getDefaultHttpClient()
          .execute(new HttpGet(uri), HttpClientUtil.createNewHttpClientRequestContext());
      entity  = rsp.getEntity();
      byte[] bytes = EntityUtils.toByteArray(entity);
      Map m = (Map) Utils.fromJSON(bytes);
      String key = (String) m.get("key");
      if (key == null) {
        log.error("No key available from " + url + PublicKeyHandler.PATH);
        return null;
      } else {
        log.info("New Key obtained from  node: {} / {}", nodename, key);
      }
      PublicKey pubKey = CryptoKeys.deserializeX509PublicKey(key);
      keyCache.put(nodename, pubKey);
      return pubKey;
    } catch (Exception e) {
      log.error("Exception trying to get public key from : " + url, e);
      return null;
    } finally {
      Utils.consumeFully(entity);
    }

  }

  @Override
  public void setup(Http2SolrClient client) {
    final HttpListenerFactory.RequestResponseListener listener = new HttpListenerFactory.RequestResponseListener() {
      @Override
      public void onQueued(Request request) {
        if (cores.getAuthenticationPlugin() == null) {
          return;
        }
        if (!cores.getAuthenticationPlugin().interceptInternodeRequest(request)) {
          log.debug("{} secures this internode request", this.getClass().getSimpleName());
          generateToken().ifPresent(s -> request.header(HEADER, myNodeName + " " + s));
        } else {
          log.debug("{} secures this internode request", cores.getAuthenticationPlugin().getClass().getSimpleName());
        }
      }
    };
    client.addListenerFactory(() -> listener);
  }

  @Override
  public SolrHttpClientBuilder getHttpClientBuilder(SolrHttpClientBuilder builder) {
    HttpClientUtil.addRequestInterceptor(interceptor);
    interceptorRegistered = true;
    return builder;
  }

  public boolean needsAuthorization(HttpServletRequest req) {
    return req.getUserPrincipal() != SU;
  }

  private class HttpHeaderClientInterceptor implements HttpRequestInterceptor {

    public HttpHeaderClientInterceptor() {
    }

    @Override
    public void process(HttpRequest httpRequest, HttpContext httpContext) throws HttpException, IOException {
      if (cores.getAuthenticationPlugin() == null) {
        return;
      }
      if (!cores.getAuthenticationPlugin().interceptInternodeRequest(httpRequest, httpContext)) {
        log.debug("{} secures this internode request", this.getClass().getSimpleName());
        setHeader(httpRequest);
      } else {
        log.debug("{} secures this internode request", cores.getAuthenticationPlugin().getClass().getSimpleName());
      }
    }
  }

  @SuppressForbidden(reason = "Needs currentTimeMillis to set current time in header")
  private Optional generateToken() {
    SolrRequestInfo reqInfo = getRequestInfo();
    String usr;
    if (reqInfo != null) {
      Principal principal = reqInfo.getUserPrincipal();
      if (principal == null) {
        log.debug("principal is null");
        //this had a request but not authenticated
        //so we don't not need to set a principal
        return Optional.empty();
      } else {
        usr = principal.getName();
      }
    } else {
      if (!isSolrThread()) {
        //if this is not running inside a Solr threadpool (as in testcases)
        // then no need to add any header
        return Optional.empty();
      }
      //this request seems to be originated from Solr itself
      usr = "$"; //special name to denote the user is the node itself
    }

    String s = usr + " " + System.currentTimeMillis();

    byte[] payload = s.getBytes(UTF_8);
    byte[] payloadCipher = publicKeyHandler.keyPair.encrypt(ByteBuffer.wrap(payload));
    String base64Cipher = Base64.byteArrayToBase64(payloadCipher);
    return Optional.of(base64Cipher);
  }

  void setHeader(HttpRequest httpRequest) {
    generateToken().ifPresent(s -> httpRequest.setHeader(HEADER, myNodeName + " " + s));
  }

  boolean isSolrThread() {
    return ExecutorUtil.isSolrServerThread();
  }

  SolrRequestInfo getRequestInfo() {
    return SolrRequestInfo.getRequestInfo();
  }

  @Override
  public void close() throws IOException {
    HttpClientUtil.removeRequestInterceptor(interceptor);
    interceptorRegistered = false;
  }

  public String getPublicKey() {
    return publicKeyHandler.getPublicKey();
  }

  public static final String HEADER = "SolrAuth";
  public static final String NODE_IS_USER = "$";
  // special principal to denote the cluster member
  private static final Principal SU = new BasicUserPrincipal("$");
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy