All downloads are FREE. Search and download functionality uses the official Maven repository.

net.schmizz.sshj.connection.channel.AbstractChannel Maven / Gradle / Ivy

There is a newer version: 0.39.0
Show newest version
/*
 * Copyright (C)2009 - SSHJ Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package net.schmizz.sshj.connection.channel;

import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

public abstract class AbstractChannel
        implements Channel {

    private static final int REMOTE_MAX_PACKET_SIZE_CEILING = 1024 * 1024;

    /** Logger */
    protected final LoggerFactory loggerFactory;
    protected final Logger log;

    /** Transport layer */
    protected final Transport trans;
    /** Connection layer */
    protected final Connection conn;

    /** Channel type */
    private final String type;
    /** Channel ID */
    private final int id;
    /** Remote recipient ID */
    private int recipient;
    /** Remote character set */
    private final Charset remoteCharset;

    private boolean eof = false;

    private final Queue> chanReqResponseEvents = new LinkedList>();

    /* The lock used by to create the open & close events */
    private final ReentrantLock openCloseLock = new ReentrantLock();
    /** Channel open event */
    protected final Event openEvent;
    /** Channel close event */
    protected final Event closeEvent;
    /** Whether we have already sent a CHANNEL_CLOSE request to the server */
    private boolean closeRequested;

    /** Local window */
    protected final Window.Local lwin;
    /** stdout stream */
    private final ChannelInputStream in;

    /** Remote window */
    protected Window.Remote rwin;
    /** stdin stream */
    private ChannelOutputStream out;

    private volatile boolean autoExpand = false;

    protected AbstractChannel(Connection conn, String type) {
        this(conn, type, null);
    }
    protected AbstractChannel(Connection conn, String type, Charset remoteCharset) {
        this.conn = conn;
        this.loggerFactory = conn.getTransport().getConfig().getLoggerFactory();
        this.type = type;
        this.log = loggerFactory.getLogger(getClass());
        this.trans = conn.getTransport();

        this.remoteCharset = remoteCharset != null ? remoteCharset : IOUtils.UTF8;
        id = conn.nextID();

        lwin = new Window.Local(conn.getWindowSize(), conn.getMaxPacketSize(), loggerFactory);
        in = new ChannelInputStream(this, trans, lwin);

        openEvent = new Event("chan#" + id + " / " + "open", ConnectionException.chainer, openCloseLock, loggerFactory);
        closeEvent = new Event("chan#" + id + " / " + "close", ConnectionException.chainer, openCloseLock, loggerFactory);
    }

    protected void init(int recipient, long remoteWinSize, long remoteMaxPacketSize) {
        this.recipient = recipient;
        rwin = new Window.Remote(remoteWinSize, (int) Math.min(remoteMaxPacketSize, REMOTE_MAX_PACKET_SIZE_CEILING),
            conn.getTimeoutMs(), loggerFactory);
        out = new ChannelOutputStream(this, trans, rwin);
        log.debug("Initialized - {}", this);
    }

    @Override
    public boolean getAutoExpand() {
        return autoExpand;
    }

    @Override
    public int getID() {
        return id;
    }

    @Override
    public InputStream getInputStream() {
        return in;
    }

    @Override
    public int getLocalMaxPacketSize() {
        return lwin.getMaxPacketSize();
    }

    @Override
    public long getLocalWinSize() {
        return lwin.getSize();
    }

    @Override
    public OutputStream getOutputStream() {
        return out;
    }

    @Override
    public int getRecipient() {
        return recipient;
    }

    @Override
    public Charset getRemoteCharset() {
        return remoteCharset;
    }

    @Override
    public int getRemoteMaxPacketSize() {
        return rwin.getMaxPacketSize();
    }

    @Override
    public long getRemoteWinSize() {
        return rwin.getSize();
    }

    @Override
    public String getType() {
        return type;
    }

    @Override
    public void handle(Message msg, SSHPacket buf)
            throws ConnectionException, TransportException {
        switch (msg) {

            case CHANNEL_DATA:
                receiveInto(in, buf);
                break;

            case CHANNEL_EXTENDED_DATA:
                gotExtendedData(buf);
                break;

            case CHANNEL_WINDOW_ADJUST:
                gotWindowAdjustment(buf);
                break;

            case CHANNEL_REQUEST:
                gotChannelRequest(buf);
                break;

            case CHANNEL_SUCCESS:
                gotResponse(true);
                break;

            case CHANNEL_FAILURE:
                gotResponse(false);
                break;

            case CHANNEL_EOF:
                gotEOF();
                break;

            case CHANNEL_CLOSE:
                gotClose();
                break;

            default:
                gotUnknown(msg, buf);
                break;
        }
    }

    @Override
    public boolean isEOF() {
        return eof;
    }

    @Override
    public LoggerFactory getLoggerFactory() {
        return loggerFactory;
    }

    private void gotClose()
            throws TransportException {
        log.debug("Got close");
        try {
            closeAllStreams();
            sendClose();
        } finally {
            finishOff();
        }
    }

    /** Called when all I/O streams should be closed. Subclasses can override but must call super. */
    protected void closeAllStreams() {
        IOUtils.closeQuietly(in, out);
    }

    @Override
    public void notifyError(SSHException error) {
        log.debug("Channel #{} got notified of {}", getID(), error.toString());

        ErrorDeliveryUtil.alertEvents(error, openEvent, closeEvent);
        ErrorDeliveryUtil.alertEvents(error, chanReqResponseEvents);

        in.notifyError(error);
        if (out != null)
            out.notifyError(error);

        finishOff();
    }

    @Override
    public void setAutoExpand(boolean autoExpand) {
        this.autoExpand = autoExpand;
    }

    @Override
    public void close()
            throws ConnectionException, TransportException {
        openCloseLock.lock();
        try {
            if (isOpen()) {
                try {
                    sendClose();
                } catch (TransportException e) {
                    if (!closeEvent.inError())
                        throw e;
                }
                closeEvent.await(conn.getTimeoutMs(), TimeUnit.MILLISECONDS);
            }
        } finally {
            openCloseLock.unlock();
        }
    }

    public void join()
            throws ConnectionException {
        closeEvent.await();
    }

    public void join(long timeout, TimeUnit unit)
            throws ConnectionException {
        closeEvent.await(timeout, unit);
    }

    protected void sendClose()
            throws TransportException {
        openCloseLock.lock();
        try {
            if (!closeRequested) {
                log.debug("Sending close");
                trans.write(newBuffer(Message.CHANNEL_CLOSE));
            }
        } finally {
            closeRequested = true;
            openCloseLock.unlock();
        }
    }

    @Override
    public boolean isOpen() {
        openCloseLock.lock();
        try {
            return openEvent.isSet() && !closeEvent.isSet() && !closeRequested;
        } finally {
            openCloseLock.unlock();
        }
    }

    private void gotChannelRequest(SSHPacket buf)
            throws ConnectionException, TransportException {
        final String reqType;
        try {
            reqType = buf.readString();
            buf.readBoolean(); // We don't care about the 'want-reply' value
        } catch (Buffer.BufferException be) {
            throw new ConnectionException(be);
        }
        log.debug("Got chan request for `{}`", reqType);
        handleRequest(reqType, buf);
    }

    private void gotWindowAdjustment(SSHPacket buf)
            throws ConnectionException {
        final long howMuch;
        try {
            howMuch = buf.readUInt32();
        } catch (Buffer.BufferException be) {
            throw new ConnectionException(be);
        }
        log.debug("Received window adjustment for {} bytes", howMuch);
        rwin.expand(howMuch);
    }

    protected void finishOff() {
        conn.forget(this);
        closeEvent.set();
    }

    protected void gotExtendedData(SSHPacket buf)
            throws ConnectionException, TransportException {
        throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
                                      "Extended data not supported on " + type + " channel");
    }

    protected void gotUnknown(Message msg, SSHPacket buf)
            throws ConnectionException, TransportException {
        log.warn("Got unknown packet with type {}", msg);

    }

    protected void handleRequest(String reqType, SSHPacket buf)
            throws ConnectionException, TransportException {
        trans.write(newBuffer(Message.CHANNEL_FAILURE));
    }

    protected SSHPacket newBuffer(Message cmd) {
        return new SSHPacket(cmd).putUInt32(recipient);
    }

    protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
            throws ConnectionException, TransportException {
        final int len;
        try {
            len = buf.readUInt32AsInt();
        } catch (Buffer.BufferException be) {
            throw new ConnectionException(be);
        }
        if (len < 0 || len > getLocalMaxPacketSize() || len > buf.available()) {
            throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Bad item length: " + len);
        }
        if (log.isTraceEnabled()) {
            log.trace("IN #{}: {}", id, ByteArrayUtils.printHex(buf.array(), buf.rpos(), len));
        }
        stream.receive(buf.array(), buf.rpos(), len);
    }

    protected Event sendChannelRequest(String reqType, boolean wantReply,
                                                            Buffer.PlainBuffer reqSpecific)
            throws TransportException {
        log.debug("Sending channel request for `{}`", reqType);
        synchronized (chanReqResponseEvents) {
            trans.write(
                    newBuffer(Message.CHANNEL_REQUEST)
                            .putString(reqType)
                            .putBoolean(wantReply)
                            .putBuffer(reqSpecific)
            );

            Event responseEvent = null;
            if (wantReply) {
                responseEvent = new Event("chan#" + id + " / " + "chanreq for " + reqType,
                                                               ConnectionException.chainer, loggerFactory);
                chanReqResponseEvents.add(responseEvent);
            }
            return responseEvent;
        }
    }

    private void gotResponse(boolean success)
            throws ConnectionException {
        synchronized (chanReqResponseEvents) {
            final Event responseEvent = chanReqResponseEvents.poll();
            if (responseEvent != null) {
                if (success)
                    responseEvent.set();
                else
                    responseEvent.deliverError(new ConnectionException("Request failed"));
            } else
                throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
                                              "Received response to channel request when none was requested");
        }
    }

    private void gotEOF()
            throws TransportException {
        log.debug("Got EOF");
        eofInputStreams();
    }

    /** Called when EOF has been received. Subclasses can override but must call super. */
    protected void eofInputStreams() {
        in.eof();
        eof = true;
    }

    @Override
    public String toString() {
        return "< " + type + " channel: id=" + id + ", recipient=" + recipient + ", localWin=" + lwin + ", remoteWin="
                + rwin + " >";
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy