All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.avro.ipc.SaslSocketTransceiver Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.avro.ipc;

import java.io.IOException;
import java.io.EOFException;
import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.channels.SocketChannel;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslServer;

import org.apache.avro.Protocol;
import org.apache.avro.util.ByteBufferOutputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A {@link Transceiver} that uses {@link javax.security.sasl} for
 * authentication and encryption. */
public class SaslSocketTransceiver extends Transceiver {
  private static final Logger LOG =
    LoggerFactory.getLogger(SaslSocketTransceiver.class);

  // Zero-length payload, used when there is no SASL response data to send.
  private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

  // Negotiation status codes. Each one is written to and read from the
  // channel as a single byte holding the constant's ordinal, so the
  // declaration order here is part of the wire format.
  private static enum Status { START, CONTINUE, FAIL, COMPLETE }

  // Wraps whichever of SaslClient/SaslServer this end was constructed with.
  private SaslParticipant sasl;
  private SocketChannel channel;
  // Set at the end of open(): true when a QOP was negotiated and it is
  // anything other than "auth"; frames then pass through sasl.wrap()/unwrap().
  private boolean dataIsWrapped;
  // Set by a client whose mechanism was already complete after sending its
  // START message; the peer's status reply is then still unread and is
  // consumed by the next transceive() call.
  private boolean saslResponsePiggybacked;

  private Protocol remote;

  // Reusable 4-byte buffers for the length prefix of a frame.
  private ByteBuffer readHeader = ByteBuffer.allocate(4);
  private ByteBuffer writeHeader = ByteBuffer.allocate(4);
  // Holds the int 0 that terminates a buffer list in writeBuffers(). putInt
  // leaves the position at 4, so it has to be flipped before every write.
  private ByteBuffer zeroHeader = ByteBuffer.allocate(4).putInt(0);

  /**
   * Create using SASL's anonymous (RFC 2245) mechanism. Delegates to the
   * two-argument constructor with a new {@code AnonymousClient}.
   *
   * @param address the address the socket channel is opened to
   * @throws IOException if the delegated constructor throws it
   */
  public SaslSocketTransceiver(SocketAddress address) throws IOException {
    this(address, new AnonymousClient());
  }

  /** Create using the specified {@link SaslClient}. */
  public SaslSocketTransceiver(SocketAddress address, SaslClient saslClient)
    throws IOException {
    this.sasl = new SaslParticipant(saslClient);
    this.channel = SocketChannel.open(address);
    this.channel.socket().setTcpNoDelay(true);
    LOG.debug("open to {}", getRemoteName());
    open(true);
  }

  /** Create using the specified {@link SaslServer}. */
  public SaslSocketTransceiver(SocketChannel channel, SaslServer saslServer)
    throws IOException {
    this.sasl = new SaslParticipant(saslServer);
    this.channel = channel;
    LOG.debug("open from {}", getRemoteName());
    open(false);
  }

  /** Returns true while a non-null remote protocol is stored in this transceiver. */
  @Override public boolean isConnected() { return remote != null; }

  /**
   * Stores the remote protocol; {@link #isConnected()} reports true
   * afterwards when the value is non-null.
   */
  @Override public void setRemote(Protocol remote) {
    this.remote = remote;
  }

  /** Returns the remote protocol last stored by {@link #setRemote}, or null if none was. */
  @Override public Protocol getRemote() {
    return remote;
  }
  /** Returns the string form of the remote socket address of the channel. */
  @Override
  public String getRemoteName() {
    final SocketAddress peer = channel.socket().getRemoteSocketAddress();
    return peer.toString();
  }

  /**
   * Sends a request and returns the response via the superclass
   * implementation, first consuming a SASL status reply that is still
   * pending from negotiation.
   *
   * <p>When the client finished negotiation with its very first message,
   * {@code open} returned without reading the peer's reply. That reply (one
   * status plus one frame) is read here, once, before the first exchange.
   *
   * @param request the buffers making up the request
   * @return the buffers making up the response
   * @throws SaslException if the pending reply has status FAIL
   * @throws IOException if the pending reply has any status other than
   *         COMPLETE or FAIL, or on an I/O failure
   */
  @Override
  public synchronized List<ByteBuffer> transceive(List<ByteBuffer> request)
    throws IOException {
    if (saslResponsePiggybacked) {                // still need to read response
      saslResponsePiggybacked = false;
      Status status = readStatus();
      // The frame is always read so the stream stays aligned, even though
      // only the FAIL branch looks at its content.
      ByteBuffer frame = readFrame();
      switch (status) {
      case COMPLETE:
        break;
      case FAIL:
        throw new SaslException("Fail: "+toString(frame));
      default:
        throw new IOException("Unexpected SASL status: "+status);
      }
    }
    return super.transceive(request);
  }

  /**
   * Runs the SASL negotiation over the channel and records whether
   * subsequent data must be wrapped.
   *
   * Every negotiation message read in the loop is a status byte followed by a
   * length-prefixed frame; a START message carries two frames (mechanism
   * name, then initial response).
   *
   * @param isClient true to send the opening START message before entering
   *        the read loop
   * @throws SaslException on a mechanism mismatch, a FAIL status from the
   *         peer, or a COMPLETE status that does not complete the exchange
   * @throws IOException on an unexpected status or an I/O failure
   */
  private void open(boolean isClient) throws IOException {
    LOG.debug("beginning SASL negotiation");

    if (isClient) {
      // The client opens with START + mechanism name + initial response
      // (empty when the mechanism has no initial response).
      ByteBuffer response = EMPTY;
      if (sasl.client.hasInitialResponse())
        response = ByteBuffer.wrap(sasl.evaluate(response.array()));
      write(Status.START, sasl.getMechanismName(), response);
      // Already complete after one message: the loop below is skipped, so
      // the peer's reply is left unread for transceive() to consume.
      if (sasl.isComplete())
        saslResponsePiggybacked = true;
    }

    while (!sasl.isComplete()) {
      Status status  = readStatus();
      ByteBuffer frame = readFrame();
      switch (status) {
      case START:
        // First frame is the mechanism name, the second the initial response.
        String mechanism = toString(frame);
        frame = readFrame();
        if (!mechanism.equalsIgnoreCase(sasl.getMechanismName())) {
          write(Status.FAIL, "Wrong mechanism: "+mechanism);
          throw new SaslException("Wrong mechanism: "+mechanism);
        }
        // No break: falls through so the initial response just read is
        // evaluated exactly like a CONTINUE payload.
      case CONTINUE:
        byte[] response;
        try {
          response = sasl.evaluate(frame.array());
          status = sasl.isComplete() ? Status.COMPLETE : Status.CONTINUE;
        } catch (SaslException e) {
          // Report the local failure to the peer as a FAIL message whose
          // payload is the exception text.
          // NOTE(review): the exception is not rethrown; the loop then goes
          // back to reading from the peer -- confirm this is intended.
          response = e.toString().getBytes("UTF-8");
          status = Status.FAIL;
        }
        // evaluate() may return null; send an empty payload in that case.
        write(status, response!=null ? ByteBuffer.wrap(response) : EMPTY);
        break;
      case COMPLETE:
        // The peer considers the exchange finished; feed it the final
        // payload and insist that this side is finished too.
        sasl.evaluate(frame.array());
        if (!sasl.isComplete())
          throw new SaslException("Expected completion!");
        break;
      case FAIL:
        throw new SaslException("Fail: "+toString(frame));
      default:
        throw new IOException("Unexpected SASL status: "+status);
      }
    }
    LOG.debug("SASL opened");

    // Data is wrapped/unwrapped only when a QOP was negotiated and it is
    // something other than "auth".
    String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
    LOG.debug("QOP = {}", qop);
    dataIsWrapped = (qop != null && !qop.equalsIgnoreCase("auth"));
  }

  /**
   * Decodes the readable bytes of an array-backed buffer as UTF-8.
   *
   * Only the bytes between the buffer's position and limit are decoded (the
   * backing array may be larger than, or offset from, the readable region).
   * The buffer's position is left unchanged.
   *
   * @param buffer an array-backed buffer holding UTF-8 text
   * @return the decoded text
   * @throws IOException if the UTF-8 charset is reported as unsupported
   */
  private String toString(ByteBuffer buffer) throws IOException {
    try {
      return new String(buffer.array(),
                        buffer.arrayOffset() + buffer.position(),
                        buffer.remaining(),
                        "UTF-8");
    } catch (UnsupportedEncodingException e) {
      throw new IOException(e.toString(), e);
    }
  }

  /**
   * Reads frames (unwrapping them when data is wrapped) until a frame with no
   * remaining bytes is seen, which terminates the list.
   *
   * @return the frames read, in order, excluding the empty terminator
   * @throws IOException on an I/O failure, including end of stream
   */
  @Override public synchronized List<ByteBuffer> readBuffers()
    throws IOException {
    List<ByteBuffer> buffers = new ArrayList<ByteBuffer>();
    while (true) {
      ByteBuffer buffer = readFrameAndUnwrap();
      if (buffer.remaining() == 0)                // empty frame ends the list
        return buffers;
      buffers.add(buffer);
    }
  }

  /**
   * Reads one byte from the channel and maps it to the {@link Status} with
   * that ordinal.
   *
   * @return the status named by the byte
   * @throws IOException if the byte is not the ordinal of any status, or on
   *         an I/O failure
   */
  private Status readStatus() throws IOException {
    ByteBuffer buffer = ByteBuffer.allocate(1);
    read(buffer);
    int status = buffer.get();
    Status[] known = Status.values();
    // Valid ordinals are 0..length-1; a byte is signed, so values above 0x7f
    // arrive here as negative numbers and must be rejected as well.
    if (status < 0 || status >= known.length)
      throw new IOException("Unexpected SASL status byte: "+status);
    return known[status];
  }

  /**
   * Reads one frame and, when data is wrapped, returns its SASL-unwrapped
   * content instead of the raw frame.
   *
   * @return the frame as read, or a buffer over the unwrapped bytes
   * @throws IOException on an I/O or unwrapping failure
   */
  private ByteBuffer readFrameAndUnwrap() throws IOException {
    final ByteBuffer raw = readFrame();
    if (dataIsWrapped) {
      final byte[] plain = sasl.unwrap(raw.array());
      final ByteBuffer result = ByteBuffer.wrap(plain);
      LOG.debug("unwrapped data of length: {}", result.remaining());
      return result;
    }
    return raw;
  }

  /**
   * Reads one length-prefixed frame: a 4-byte int length followed by that
   * many bytes.
   *
   * @return a buffer holding exactly the frame's bytes, ready for reading
   * @throws IOException if the length prefix is negative, or on an I/O
   *         failure including end of stream
   */
  private ByteBuffer readFrame() throws IOException {
    read(readHeader);
    int length = readHeader.getInt();
    // The length comes straight off the wire; a negative value would
    // otherwise surface as an unchecked exception from ByteBuffer.allocate.
    if (length < 0)
      throw new IOException("Invalid frame length: "+length);
    ByteBuffer buffer = ByteBuffer.allocate(length);
    LOG.debug("about to read: {} bytes", buffer.capacity());
    read(buffer);
    return buffer;
  }

  /**
   * Fills the given buffer completely from the channel, then flips it so it
   * is ready for reading. Any previous content of the buffer is discarded.
   *
   * @param buffer the buffer to fill to capacity
   * @throws EOFException if the channel reaches end of stream first
   * @throws IOException on any other I/O failure
   */
  private void read(ByteBuffer buffer) throws IOException {
    buffer.clear();
    while (buffer.hasRemaining()) {
      final int count = channel.read(buffer);
      if (count == -1) {
        throw new EOFException();
      }
    }
    buffer.flip();
  }

  /**
   * Writes the given buffers to the channel as length-prefixed frames,
   * followed by a zero-length frame that terminates the list.
   *
   * <p>Empty buffers are skipped. When data is wrapped, every buffer is passed
   * through SASL wrap and written as its own frame. Otherwise consecutive
   * buffers are written under one shared length header for as long as their
   * combined length stays within {@link ByteBufferOutputStream#BUFFER_SIZE}.
   *
   * @param buffers the buffers to write; null means there is nothing to
   *        write, in which case not even the terminator is sent
   * @throws IOException on an I/O or wrapping failure
   */
  @Override public synchronized void writeBuffers(List<ByteBuffer> buffers)
    throws IOException {
    if (buffers == null) return;                  // no data to write
    List<ByteBuffer> writes =
      new ArrayList<ByteBuffer>(buffers.size()*2+1);
    int currentLength = 0;
    ByteBuffer currentHeader = writeHeader;
    for (ByteBuffer buffer : buffers) {           // gather writes
      if (buffer.remaining() == 0) continue;      // ignore empties
      if (dataIsWrapped) {
        LOG.debug("wrapping data of length: {}", buffer.remaining());
        // position() is relative to arrayOffset() within the backing array.
        buffer = ByteBuffer.wrap(sasl.wrap(buffer.array(),
                                           buffer.arrayOffset()
                                             + buffer.position(),
                                           buffer.remaining()));
      }
      int length = buffer.remaining();
      if (!dataIsWrapped                          // can append buffers on wire
          && (currentLength + length) <= ByteBufferOutputStream.BUFFER_SIZE) {
        // Extend the current frame: the header is queued once and then
        // rewritten in place with the running total.
        if (currentLength == 0)
          writes.add(currentHeader);
        currentLength += length;
        currentHeader.clear();
        currentHeader.putInt(currentLength);
        LOG.debug("adding {} to write, total now {}", length, currentLength);
      } else {
        // Start a new frame with its own freshly allocated header.
        currentLength = length;
        currentHeader = ByteBuffer.allocate(4).putInt(length);
        writes.add(currentHeader);
        LOG.debug("planning write of {}", length);
      }
      currentHeader.flip();
      writes.add(buffer);
    }
    zeroHeader.flip();                            // zero-terminate
    writes.add(zeroHeader);

    writeFully(writes.toArray(new ByteBuffer[writes.size()]));
  }

  /**
   * Writes a status byte followed by two length-prefixed frames: the UTF-8
   * bytes of {@code prefix}, then {@code response}.
   *
   * @param status the status to send
   * @param prefix text sent as the first frame
   * @param response payload sent as the second frame
   * @throws IOException on an I/O failure
   */
  private void write(Status status, String prefix, ByteBuffer response)
    throws IOException {
    LOG.debug("write status: {} {}", status, prefix);
    final ByteBuffer prefixFrame = ByteBuffer.wrap(prefix.getBytes("UTF-8"));
    write(status, prefixFrame);
    write(response);
  }

  /**
   * Writes a status byte followed by one frame holding the UTF-8 bytes of
   * the given text.
   *
   * @param status the status to send
   * @param response text sent as the frame payload
   * @throws IOException on an I/O failure
   */
  private void write(Status status, String response) throws IOException {
    final byte[] encoded = response.getBytes("UTF-8");
    write(status, ByteBuffer.wrap(encoded));
  }

  /**
   * Writes a single byte holding the status ordinal, then the response as a
   * length-prefixed frame.
   *
   * @param status the status to send
   * @param response payload sent as the frame
   * @throws IOException on an I/O failure
   */
  private void write(Status status, ByteBuffer response) throws IOException {
    LOG.debug("write status: {}", status);
    final byte code = (byte) status.ordinal();
    writeFully(ByteBuffer.wrap(new byte[] { code }));
    write(response);
  }

  /**
   * Writes one frame: the number of remaining bytes in {@code response} as a
   * 4-byte int, followed by those bytes. Reuses the shared write header.
   *
   * @param response payload to send
   * @throws IOException on an I/O failure
   */
  private void write(ByteBuffer response) throws IOException {
    final int payloadLength = response.remaining();
    LOG.debug("writing: {}", payloadLength);
    writeHeader.clear();
    writeHeader.putInt(payloadLength);
    writeHeader.flip();
    writeFully(writeHeader, response);
  }

  /**
   * Writes all of the given buffers to the channel, in order, repeating the
   * gathering write until the last buffer has no bytes remaining.
   *
   * @param buffers the buffers to write; must contain at least one buffer
   * @throws IOException on an I/O failure
   */
  private void writeFully(ByteBuffer... buffers) throws IOException {
    final int count = buffers.length;
    int next = 0;                                 // first buffer not yet drained
    while (true) {
      channel.write(buffers, next, count - next);
      // Step past every buffer that has been written out completely.
      while (!buffers[next].hasRemaining()) {
        next++;
        if (next == count) {
          return;
        }
      }
    }
  }

  /**
   * Closes the channel if it is still open, then disposes the SASL
   * participant.
   *
   * <p>Disposal happens in a {@code finally} block so the SASL participant is
   * disposed even when closing the channel throws.
   *
   * @throws IOException if closing the channel or disposing SASL fails
   */
  @Override public void close() throws IOException {
    try {
      if (channel.isOpen()) {
        LOG.info("closing to "+getRemoteName());
        channel.close();
      }
    } finally {
      sasl.dispose();
    }
  }

  /**
   * Presents a {@link SaslServer} or a {@link SaslClient} through one set of
   * methods. The two types offer largely the same operations but share no
   * common supertype, so each method here forwards to the client when one
   * was supplied and to the server otherwise.
   */
  private static class SaslParticipant {
    // Each constructor sets exactly one of these; the other stays null.
    public SaslServer server;
    public SaslClient client;

    /** Wraps the server side of a SASL exchange. */
    public SaslParticipant(SaslServer server) {
      this.server = server;
    }

    /** Wraps the client side of a SASL exchange. */
    public SaslParticipant(SaslClient client) {
      this.client = client;
    }

    /** Returns the mechanism name reported by the wrapped participant. */
    public String getMechanismName() {
      return client != null
        ? client.getMechanismName()
        : server.getMechanismName();
    }

    /** Returns whether the wrapped participant reports completion. */
    public boolean isComplete() {
      return client != null ? client.isComplete() : server.isComplete();
    }

    /** Disposes the wrapped participant. */
    public void dispose() throws SaslException {
      if (client == null) {
        server.dispose();
      } else {
        client.dispose();
      }
    }

    /** Unwraps the whole of the given array. */
    public byte[] unwrap(byte[] buf) throws SaslException {
      final int len = buf.length;
      return client != null
        ? client.unwrap(buf, 0, len)
        : server.unwrap(buf, 0, len);
    }

    /** Wraps the whole of the given array. */
    public byte[] wrap(byte[] buf) throws SaslException {
      return wrap(buf, 0, buf.length);
    }

    /** Wraps {@code len} bytes of the given array starting at {@code start}. */
    public byte[] wrap(byte[] buf, int start, int len) throws SaslException {
      return client != null
        ? client.wrap(buf, start, len)
        : server.wrap(buf, start, len);
    }

    /** Returns the named negotiated property of the wrapped participant. */
    public Object getNegotiatedProperty(String propName) {
      return client != null
        ? client.getNegotiatedProperty(propName)
        : server.getNegotiatedProperty(propName);
    }

    /**
     * Feeds the peer's bytes to the wrapped participant: a client evaluates
     * them as a challenge, a server as a response.
     */
    public byte[] evaluate(byte[] buf) throws SaslException {
      return client != null
        ? client.evaluateChallenge(buf)
        : server.evaluateResponse(buf);
    }

  }

  /**
   * Minimal {@link SaslClient} reporting the mechanism name "ANONYMOUS".
   * It has an initial response (the UTF-8 bytes of the {@code user.name}
   * system property), always reports itself complete, exposes no negotiated
   * properties and does not support wrapping or unwrapping.
   */
  private static class AnonymousClient implements SaslClient {
    @Override public String getMechanismName() { return "ANONYMOUS"; }
    @Override public boolean hasInitialResponse() { return true; }

    /**
     * Ignores the challenge and returns the UTF-8 bytes of the
     * {@code user.name} system property.
     *
     * @throws SaslException if encoding fails; the underlying exception is
     *         kept as the cause
     */
    @Override
    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
      try {
        return System.getProperty("user.name").getBytes("UTF-8");
      } catch (IOException e) {
        // Keep the original exception as the cause so its stack trace
        // is not lost.
        throw new SaslException(e.toString(), e);
      }
    }

    @Override public boolean isComplete() { return true; }

    @Override public byte[] unwrap(byte[] incoming, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override public byte[] wrap(byte[] outgoing, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override public Object getNegotiatedProperty(String propName) {
      return null;
    }

    @Override public void dispose() {}
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy