All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.marklogic.http.HttpChannel Maven / Gradle / Ivy

/*
 * Copyright (c) 2023 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.http;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.marklogic.io.ChunkedInputStream;
import com.marklogic.io.LengthLimitedInputStream;
import com.marklogic.io.SslByteChannel;

/**
 * A lightweight request/response channel that speaks HTTP/1.1 (or the legacy
 * XDBC/1.0 protocol token) over a {@link ByteChannel}.
 *
 * <p>Typical flow: set request headers, write the request body through the
 * {@code write*} methods (staged in a single body buffer and flushed whenever
 * it fills up), then call any {@code getResponse*} accessor. The first such
 * call finishes and flushes the request, optionally half-closes the output
 * side, and parses the response headers. {@link #reset(String, String)}
 * prepares the instance for another exchange on the same channel.
 *
 * <p>NOTE(review): apart from locking the channel during timed reads, this
 * class performs no synchronization; presumably an instance is used by one
 * thread at a time — confirm with callers.
 */
public class HttpChannel {
    // IN CASE OF EMERGENCY BREAK GLASS
    /**
     * Name of the system property that, when set to {@code "true"}
     * (case-insensitive), makes requests use {@code HTTP/1.1} instead of
     * {@code XDBC/1.0}. It is read once when the class is initialized; use
     * {@link #setUseHTTP(boolean)} to change the setting at runtime.
     */
    public static final String USEHTTP_PROPERTY_NAME = "xcc.httpcompliant";
    /**
     * Response header holding the time (epoch milliseconds) at which the
     * response headers were parsed. It is filled in locally after parsing
     * unless the response already carries it.
     */
    public static final String RCV_TIME_HEADER = "X-XCC-Received";

    /** Body buffer size used when the constructor is given a non-positive size. */
    static final int DEFAULT_BUFFER_SIZE = 64 * 1024;
    /** Lower bound applied to the requested body buffer size. */
    static final int MINIMUM_BUFFER_SIZE = 1024;
    /** Upper bound applied to the requested body buffer size. */
    static final int MAXIMUM_BUFFER_SIZE = 32 * 1024 * 1024;
    /**
     * Keep-alive seconds assumed in HTTP mode when the response says
     * {@code Connection: keep-alive} but carries no {@code Keep-Alive: timeout}.
     */
    static final int KEEP_ALIVE_TIME = 5;
    // Keep-alive timeout for inserting a batch. This value is hard-coded in
    // HTTPRequestTask::handleXDBCInsert. Also assumed while inside a
    // multi-statement transaction (commit == false).
    static final int KEEP_ALIVE_TIME_MULTI_TXN = 120;

    private final ByteChannel channel;
    private final HttpHeaders requestHeaders = new HttpHeaders();
    private final HttpHeaders responseHeaders = new HttpHeaders();
    /** Response stream; its internal buffer shares storage with {@link #bodyBuffer}. */
    private final InputStream inStream;
    /** Staging buffer for the outgoing request body. */
    private final ByteBuffer bodyBuffer;
    private final Logger logger;
    /** Effective body buffer size chosen by {@link #allocBuffer(int)}. */
    private int bufSize = Integer.MAX_VALUE;

    private boolean suppressHeaders = false;
    private boolean closeOutputIfNoContentLength = false;
    private boolean headersParsed = false;
    private boolean headersWritten = false;
    private static final AtomicBoolean useHTTP =
        new AtomicBoolean("true".equalsIgnoreCase(
            System.getProperty(USEHTTP_PROPERTY_NAME)));
    /**
     * Cleared by callers while inside a multi-statement transaction; selects
     * the assumed keep-alive timeout in {@link #getResponseKeepaliveSeconds()}.
     */
    private boolean commit = true;

    /** Returns whether the request declares {@code Transfer-Encoding: chunked}. */
    private boolean isChunked() {
        return "chunked".equalsIgnoreCase(getRequestHeader("Transfer-Encoding"));
    }

    /** Returns whether the request already declares {@code Connection: keep-alive}. */
    private boolean isKeepAlive() {
        return "keep-alive".equalsIgnoreCase(getRequestHeader("Connection"));
    }

    /** Returns whether new requests use the {@code HTTP/1.1} protocol token. */
    public static boolean isUseHTTP() {
        return useHTTP.get();
    }

    /**
     * Globally selects {@code HTTP/1.1} ({@code true}) or {@code XDBC/1.0}
     * ({@code false}) for requests created or reset afterwards.
     */
    public static void setUseHTTP(boolean val) {
        useHTTP.set(val);
    }

    /**
     * Records whether the current request runs outside ({@code true}) or
     * inside ({@code false}) a multi-statement transaction.
     */
    public void setCommit(boolean val) {
        commit = val;
    }

    // TODO: Add more logging calls to this class

    // --------------------------------------------------------------

    /**
     * Creates a channel for a request/response exchange.
     *
     * @param channel       the underlying transport
     * @param method        request method, e.g. {@code "POST"}
     * @param path          request path
     * @param bufferSize    requested body buffer size; non-positive selects
     *                      {@link #DEFAULT_BUFFER_SIZE}, and the result is clamped to
     *                      [{@link #MINIMUM_BUFFER_SIZE}, {@link #MAXIMUM_BUFFER_SIZE}]
     * @param timeoutMillis read timeout in milliseconds used while waiting for
     *                      response data (non-positive disables the selector-based
     *                      timeout on plain selectable channels)
     * @param logger        logger to use, or {@code null} for a class-named default
     */
    public HttpChannel(ByteChannel channel, String method, String path, int bufferSize, int timeoutMillis, Logger logger) {
        this.channel = channel;
        this.logger = (logger == null) ? Logger.getLogger(getClass().getName()) : logger;

        requestHeaders.setRequestValues(method, path,
                isUseHTTP() ? "HTTP/1.1" : "XDBC/1.0");

        if (this.logger.isLoggable(Level.FINE)) {
            this.logger.fine("XDBC request: " + requestHeaders.getRequestLine());
        }
        bodyBuffer = allocBuffer(bufferSize);

        inStream = new ChannelInputStream(channel, bodyBuffer, timeoutMillis);
    }

    /**
     * Clears all per-request state (headers, flags, buffered body) and starts
     * a new request line on the same channel.
     */
    public void reset(String method, String path) {
        suppressHeaders = false;
        closeOutputIfNoContentLength = false;
        headersParsed = false;
        headersWritten = false;
        requestHeaders.clear();
        responseHeaders.clear();
        bodyBuffer.clear();

        requestHeaders.setRequestValues(method, path,
            isUseHTTP() ? "HTTP/1.1" : "XDBC/1.0");
    }

    // --------------------------------------------------------------

    /** Returns the underlying transport channel. */
    public ByteChannel getChannel() {
        return channel;
    }

    /**
     * When {@code true}, the output side of the connection is closed before
     * reading the response unless the request asked for keep-alive.
     */
    public void setCloseOutputIfNoContentLength(boolean value) {
        this.closeOutputIfNoContentLength = value;
    }

    // --------------------------------------------------------------

    /**
     * Appends {@code length} bytes from {@code bytes[offset..]} to the request
     * body, flushing headers and buffered body whenever the buffer fills.
     *
     * @return {@code length}
     */
    public int write(byte[] bytes, int offset, int length) throws IOException {
        int pos = offset;
        int remaining = length;

        while (remaining > 0) {
            if (!bodyBuffer.hasRemaining()) {
                flushRequest(false);
            }

            int chunk = Math.min(remaining, bodyBuffer.remaining());
            bodyBuffer.put(bytes, pos, chunk);
            pos += chunk;
            remaining -= chunk;
        }

        return length;
    }

    /** Appends all of {@code bytes} to the request body. */
    public int write(byte[] bytes) throws IOException {
        return write(bytes, 0, bytes.length);
    }

    /** Appends {@code value} encoded as UTF-8 to the request body. */
    public void writeString(String value) throws IOException {
        write(value.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Appends the remaining bytes of {@code buffer} to the request body.
     *
     * <p>If they fit in the free space of the body buffer they are copied
     * there and the caller's buffer position is left unchanged. Otherwise the
     * pending headers and body are flushed and {@code buffer} is written
     * directly to the channel, which consumes it (its position reaches its
     * limit).
     */
    public void write(ByteBuffer buffer) throws IOException {
        if (buffer.remaining() < bodyBuffer.remaining()) {
            // Copy via a duplicate: leaves the caller's position untouched and,
            // unlike array(), also works for direct and read-only buffers.
            bodyBuffer.put(buffer.duplicate());
            return;
        }

        flushRequest(false);

        writeBuffer(channel, buffer);
    }

    // --------------------------------------------------------------

    /**
     * Returns the response body stream, framed according to the response
     * headers: de-chunked for chunked responses, length-limited when a
     * Content-Length is present, otherwise the raw channel stream.
     */
    public InputStream getResponseStream() throws IOException {
        receiveMode();

        if (isResponseChunked()) {
            return new ChunkedInputStream(inStream);
        }

        int contentLength = getResponseContentLength();
        if (contentLength != -1) {
            return new LengthLimitedInputStream(inStream, contentLength);
        }

        return inStream;
    }

    // --------------------------------------------------------------

    /** Sets (replaces) a request header. */
    public void setRequestHeader(String header, String value) {
        requestHeaders.setHeader(header, value);
    }

    /** Returns a request header value as stored by the header map. */
    public String getRequestHeader(String header) {
        return requestHeaders.getHeader(header);
    }

    // --------------------------------------------------------------

    /** Sets the request {@code Content-Type} header. */
    public void setRequestContentType(String value) {
        requestHeaders.setHeader("Content-Type", value);
    }

    /** Sets the request {@code Content-Length} header. */
    public void setRequestContentLength(int length) {
        requestHeaders.setHeader("Content-Length", Integer.toString(length));
    }

    // --------------------------------------------------------------

    /** Returns a response header, looked up by its normalized name. */
    public String getResponseHeader(String headerName) throws IOException {
        receiveMode();

        return responseHeaders.getHeaderNormalized(headerName);
    }

    /** Returns all values of a response header, looked up by its normalized name. */
    public List<String> getResponseHeaders(String headerName)
            throws IOException {
        receiveMode();

        return responseHeaders.getAllHeadersNormalized(headerName);
    }

    /** Returns the response status code. */
    public int getResponseCode() throws IOException {
        receiveMode();

        return responseHeaders.getResponseCode();
    }

    /** Returns the response status message. */
    public String getResponseMessage() throws IOException {
        receiveMode();

        return responseHeaders.getResponseMessage();
    }

    /** Returns the response Content-Length; {@code -1} is treated as "not present". */
    public int getResponseContentLength() throws IOException {
        receiveMode();

        return responseHeaders.getContentLength();
    }

    /** Returns the response Content-Type. */
    public String getResponseContentType() throws IOException {
        receiveMode();

        return responseHeaders.getContentType();
    }

    /** Returns a named field of the response Content-Type. */
    public String getResponseContentTypeField(String fieldName) throws IOException {
        receiveMode();

        return responseHeaders.getContentTypeField(fieldName);
    }

    /** Returns the {@code boundary} parameter of the response Content-Type. */
    public String getResponseContentBoundary() throws IOException {
        getResponseContentType(); // ensure headers are parsed

        return responseHeaders.getHeaderSubValue("content-type", "boundary", ";");
    }

    /** Returns the value for {@code key} within the response {@code Set-Cookie} header. */
    public String getResponseCookieValue(String key) throws IOException {
        receiveMode();
        return responseHeaders.getHeaderSubValue("set-cookie", key, ";");
    }

    /**
     * Misspelled legacy name of {@link #getResponseCookieValue(String)},
     * kept so existing callers continue to work.
     */
    public String getReponseCookieValue(String key) throws IOException {
        return getResponseCookieValue(key);
    }

    /** Returns whether the response uses chunked transfer encoding. */
    public boolean isResponseChunked() throws IOException {
        receiveMode();

        return responseHeaders.isChunked();
    }

    /**
     * Returns the time (epoch milliseconds) the response headers were
     * received, or 0 if unknown.
     */
    public long getResponseHeaderRecvTime() throws IOException {
        receiveMode();

        // check normalized first, unit testing hack.  Real header is mixed case.
        String val = responseHeaders.getHeaderNormalized(RCV_TIME_HEADER);

        if (val == null) {
            val = responseHeaders.getHeader(RCV_TIME_HEADER);
        }

        return (val == null) ? 0 : Long.parseLong(val);
    }

    /**
     * Returns the absolute time (epoch milliseconds) at which the server's
     * keep-alive window expires, or 0 if the connection is not kept alive.
     */
    public long getResponseKeepaliveExpireTime() throws IOException {
        receiveMode();

        int keepAliveSeconds = getResponseKeepaliveSeconds();

        if (keepAliveSeconds == 0) {
            return 0;
        }

        // Widen before multiplying so large timeouts cannot overflow int.
        return getResponseHeaderRecvTime() + keepAliveSeconds * 1000L;
    }

    /**
     * Returns how many seconds the server keeps the connection alive, or 0
     * if it does not.
     */
    public int getResponseKeepaliveSeconds() throws IOException {
        receiveMode();
        // Case 1: In HTTP 1.1, a "Connection: Keep-Alive" header may not be
        // forwarded by some reverse proxies, but "Keep-Alive: timeout" is.
        Integer timeout = responseHeaders.getHeaderSubValueInt("keep-alive", "timeout", ",");
        if (timeout != null) {
            return timeout;
        }

        // Case 2: without "Connection: Keep-Alive" the connection is not reused.
        String header = responseHeaders.getHeader("connection");
        if (header == null || !header.equalsIgnoreCase("keep-alive")) {
            return 0;
        }

        // Case 3: In HTTP 1.1, when xcc is talking to an AWS ALB, the
        // "Connection: Keep-Alive" header is forwarded but "Keep-Alive: timeout"
        // is not. Assume the hard-coded timeouts configured in the XDBC server.
        if (!isUseHTTP()) {
            return 0;
        }
        // Inside a multi-statement transaction (commit == false) the longer
        // server-side timeout applies.
        return commit ? KEEP_ALIVE_TIME : KEEP_ALIVE_TIME_MULTI_TXN;
    }

    /** Returns the response connection disposition as reported by the header parser. */
    public String getResponseConnection() throws IOException {
        receiveMode();
        return responseHeaders.getResponseConnection();
    }

    /**
     * Parses the response body as JSON. The body is read through
     * {@link #getResponseStream()} so chunked and length-limited responses are
     * de-framed before parsing.
     */
    public JsonNode getResponseJsonBody() throws IOException {
        InputStream body = getResponseStream();
        ObjectMapper mapper = new ObjectMapper();
        return mapper.readTree(body);
    }

    // --------------------------------------------------------------

    /** Suppresses sending request headers (only the body is written). */
    public void suppressHeaders() {
        this.suppressHeaders = true;
    }

    // --------------------------------------------------------------

    /**
     * Switches to response mode. It finishes and flushes the request and
     * optionally half-closes the output. On the first call it also parses the
     * response headers. Safe to call repeatedly.
     */
    private void receiveMode() throws IOException {
        flushRequest(true);

        checkCloseOutput();

        if (!headersParsed) {
            parseHeaders();
        }
    }

    // If the body buffer filled up, the headers went out before a
    // Content-Length could be computed, so we must close the output side of
    // the socket for the server to see the end of input. This is not always
    // done: it is opt-in via setCloseOutputIfNoContentLength, since some code
    // handles keep-alives separately. Requests marked keep-alive are never
    // half-closed.
    private void checkCloseOutput() throws IOException {
        if (!closeOutputIfNoContentLength) {
            return;
        }

        String connHeader = getRequestHeader("Connection"); // careful: case-sensitive

        if (connHeader == null || !connHeader.equalsIgnoreCase("keep-alive")) {
            if (channel instanceof SocketChannel) {
                ((SocketChannel) channel).socket().shutdownOutput();
            } else if (channel instanceof SslByteChannel) {
                ((SslByteChannel) channel).close(false);
            }
        }
    }

    /**
     * Reads the response headers from the channel and stamps
     * {@link #RCV_TIME_HEADER} with the time parsing started, unless present.
     */
    private void parseHeaders() throws IOException {
        long now = System.currentTimeMillis();

        logger.finer("parsing response headers");

        responseHeaders.parseResponseHeaders(inStream);

        // conditional for unit testing, never sent by the server
        if (responseHeaders.getHeader(RCV_TIME_HEADER) == null) {
            responseHeaders.setHeader(RCV_TIME_HEADER, Long.toString(now));
        }

        headersParsed = true;
    }

    /**
     * Returns the server version taken from a {@code Server: MarkLogic <version>}
     * response header, or {@code null} if the header is absent or different.
     */
    public String getServerVersion() throws IOException {
        receiveMode();
        final String prefix = "MarkLogic ";
        String header = getResponseHeader("server");
        if (header != null && header.startsWith(prefix)) {
            return header.substring(prefix.length());
        }
        /* Disabled fallback, kept for reference:
        String responseLine = responseHeaders.getResponseLine();
        if ((responseLine != null) && responseLine.startsWith("XDBC/")) {
            return responseLine.substring(5).split(" ")[0];
        }
        */
        return null;
    }

    // --------------------------------------------------------------

    /**
     * Sends the request headers (once) and any buffered body bytes.
     *
     * @param finished {@code true} when no more body data will follow. If the
     *                 headers have not been sent yet, Content-Length is then
     *                 set from the buffered body (unless chunked) and
     *                 {@code Connection: keep-alive} is forced.
     */
    private void flushRequest(boolean finished) throws IOException {
        if (!headersWritten) {
            if (finished) {
                if (!isChunked()) {
                    setRequestContentLength(bodyBuffer.position());
                }
                if (!isKeepAlive()) {
                    setRequestHeader("Connection", "keep-alive");
                }
            }

            writeHeaders();
        }

        writeBody();
    }

    /** Writes the buffered body bytes and empties the body buffer. */
    private void writeBody() throws IOException {
        bodyBuffer.flip();

        writeBuffer(channel, bodyBuffer);

        bodyBuffer.clear();
    }

    /** Writes the serialized request headers (unless suppressed) as UTF-8. */
    private void writeHeaders() throws IOException {
        if (!suppressHeaders) {
            byte[] headerBytes = requestHeaders.toString().getBytes(StandardCharsets.UTF_8);

            writeBuffer(channel, ByteBuffer.wrap(headerBytes));
        }

        headersWritten = true;
    }

    /** Writes all remaining bytes of {@code buffer} to {@code target}. */
    private void writeBuffer(ByteChannel target, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            target.write(buffer);
        }
    }

    // --------------------------------------------------------------

    // On some JVMs, specifically IBM's, direct buffers are not GC'ed
    // properly.  If allocation of a direct buffer fails, try a regular one.
    /**
     * Allocates the body buffer, clamping {@code size} to the configured
     * bounds and recording the result in {@link #getBufferSize()}.
     */
    ByteBuffer allocBuffer(int size) {
        bufSize = (size <= 0) ? DEFAULT_BUFFER_SIZE : size;

        bufSize = Math.max(bufSize, MINIMUM_BUFFER_SIZE);
        bufSize = Math.min(bufSize, MAXIMUM_BUFFER_SIZE);

        try {
            return ByteBuffer.allocateDirect(bufSize);
        } catch (OutOfMemoryError e) {
            return ByteBuffer.allocate(bufSize);
        }
    }

    /** Returns the effective body buffer size in bytes. */
    public int getBufferSize() {
        return bufSize;
    }

    // --------------------------------------------------------------

    /**
     * Input stream over a readable channel. It buffers small reads in a
     * duplicate of the request body buffer (same storage, independent
     * position/limit) and applies a read timeout.
     *
     * <p>NOTE(review): the lazily opened {@link Selector} is never closed, so it
     * lives as long as this stream; consider closing it when the connection
     * is discarded.
     */
    private static class ChannelInputStream extends InputStream {
        /** Reads at least this large bypass the internal buffer. */
        private static final int DIRECT_READ_THRESHOLD = 8 * 1024;
        private final ReadableByteChannel channel;
        private final ByteBuffer buffer;
        private final int timeoutMillis;
        private Selector selector = null;

        public ChannelInputStream(ReadableByteChannel channel, ByteBuffer buffer, int timeoutMillis) {
            this.channel = channel;
            this.buffer = buffer.duplicate();
            this.timeoutMillis = timeoutMillis;

            // Start empty: position 0, limit 0.
            this.buffer.clear();
            this.buffer.flip();
        }

        @Override
        public int read(byte[] bytes, int off, int len) throws IOException {
            if (len == 0) {
                return 0;
            }

            if (off < 0 || off > bytes.length || len < 0
                    || off + len > bytes.length || off + len < 0) {
                throw new IndexOutOfBoundsException();
            }

            // Serve previously buffered data first.
            int copied = attemptCopyOut(bytes, off, len);
            if (copied != 0) {
                return copied;
            }

            // Large reads go straight into the caller's array, but still
            // through timedRead so the configured timeout is honored.
            if (len >= DIRECT_READ_THRESHOLD) {
                return timedRead(ByteBuffer.wrap(bytes, off, len));
            }

            if (fillBuffer() < 0) {
                return -1;
            }

            return attemptCopyOut(bytes, off, len);
        }

        @Override
        public int read(byte[] b) throws IOException {
            return read(b, 0, b.length);
        }

        @Override
        public int read() throws IOException {
            if (buffer.hasRemaining()) {
                return buffer.get() & 0xff;
            }

            byte[] one = new byte[1];
            if (read(one, 0, 1) == -1) {
                return -1;
            }

            return one[0] & 0xff;
        }

        /** Copies up to {@code len} buffered bytes out; returns the count copied. */
        private int attemptCopyOut(byte[] bytes, int off, int len) {
            int toRead = Math.min(buffer.remaining(), len);

            if (toRead != 0) {
                buffer.get(bytes, off, toRead);
            }

            return toRead;
        }

        /** Refills the internal buffer from the channel; returns the read count or -1. */
        private int fillBuffer() throws IOException {
            buffer.clear();
            int rc = timedRead(buffer);
            buffer.flip();

            return rc;
        }

        /**
         * Reads into {@code dst} honoring {@code timeoutMillis}.
         * <ul>
         *   <li>SSL channels: the channel's own timeout is set for the read
         *       and then restored.</li>
         *   <li>Selectable channels with a positive timeout: waits on a
         *       selector in non-blocking mode and throws {@link IOException}
         *       if no data arrived in time.</li>
         *   <li>Otherwise: a plain blocking read.</li>
         * </ul>
         */
        private int timedRead(ByteBuffer dst) throws IOException {
            if (channel instanceof SslByteChannel) {
                SslByteChannel ch = (SslByteChannel) channel;
                int previous = ch.getTimeout();
                ch.setTimeout(timeoutMillis);
                try {
                    return ch.read(dst);
                } finally {
                    ch.setTimeout(previous);
                }
            }

            if (timeoutMillis <= 0 || !(channel instanceof SelectableChannel)) {
                return channel.read(dst);
            }

            SelectableChannel schannel = (SelectableChannel) channel;

            synchronized (channel) {
                SelectionKey key = null;

                if (selector == null) {
                    selector = Selector.open();
                }

                try {
                    selector.selectNow(); // Needed to clear old key state
                    schannel.configureBlocking(false);
                    key = schannel.register(selector, SelectionKey.OP_READ);

                    selector.select(timeoutMillis);

                    int rc = channel.read(dst);

                    if (rc == 0) {
                        throw new IOException("Timeout waiting for read (" + timeoutMillis + " milliseconds)");
                    }

                    return rc;
                } finally {
                    if (key != null) {
                        key.cancel();
                    }
                    schannel.configureBlocking(true);
                }
            }
        }
    }

    // --------------------------------------------------------------
    /*
     * Helper functions for building http path with provided server path and
     * base path. Format: /basePath/serverPath.
     * Example server paths for EvalRequestController: "/eval", "/invoke"
     * Example server paths for ContentInsertController: "/"
     * Example http paths: "/example/base/path/eval"
     */
    /**
     * Builds a request path from {@code basePath} (only when HTTP mode is
     * enabled) followed by {@code serverPath}, joined with single slashes.
     */
    public static String buildHttpPath(String serverPath, String basePath) {
        StringBuffer sb = new StringBuffer();
        if (HttpChannel.isUseHTTP()) {
            addPathSegment(sb, basePath);
        }
        addPathSegment(sb, serverPath);
        return sb.toString();
    }

    /*
     * Add http path segments to existing StringBuffer
     * Examples:
     * existing string      segment     output
     *      /               /eval       /eval
     *      /base/path/     /eval       /base/path/eval
     *      /base/path      eval        /base/path/eval
     */
    /**
     * Appends {@code seg} to {@code sb} so exactly one slash separates them.
     * Null or blank segments are ignored.
     */
    public static void addPathSegment(StringBuffer sb, String seg) {
        if (seg == null || seg.trim().length() == 0) {
            return;
        }

        boolean endsWithSlash = sb.length() > 0 && sb.charAt(sb.length() - 1) == '/';
        if (seg.startsWith("/")) {
            sb.append(endsWithSlash ? seg.substring(1) : seg);
        } else {
            if (!endsWithSlash) {
                sb.append("/");
            }
            sb.append(seg);
        }
    }

    /**
     * Utilities for generating a POST request body for multipart/form-data.
     * <pre>
     * Example boundary: ------------------------933ed83bed2bbd3c
     * Example form body:
     * --------------------------933ed83bed2bbd3c
     * Content-Disposition: form-data; name="key"
     *
     * -Z30HIE1I9xvUfwcXyLlHA==
     * --------------------------933ed83bed2bbd3c
     * Content-Disposition: form-data; name="grant_type"
     *
     * apikey
     * --------------------------933ed83bed2bbd3c--
     * </pre>
     * Parts are emitted in insertion order; adding a name twice keeps its
     * original position and replaces the value.
     */
    public static class MultipartFormBody {
        private final static char[] MULTIPART_CHARS =
            "-_1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
                .toCharArray();

        private final String boundary;
        /** Quoted field name to value, in insertion order. */
        private final Map<String, String> nameValuePairs = new LinkedHashMap<>();

        public MultipartFormBody() {
            boundary = generateBoundary();
        }

        /** Adds a text field; the name is wrapped in double quotes verbatim. */
        public void addTextBody(String name, String value) {
            nameValuePairs.put("\"" + name + "\"", value);
        }

        /** Renders all fields as a CRLF-delimited multipart/form-data body. */
        public String buildFormBody() {
            StringBuilder body = new StringBuilder();
            for (Map.Entry<String, String> entry : nameValuePairs.entrySet()) {
                body.append("--").append(boundary).append("\r\n");
                body.append("Content-Disposition: form-data; name=")
                    .append(entry.getKey()).append("\r\n");
                body.append("\r\n");
                body.append(entry.getValue()).append("\r\n");
            }
            body.append("--").append(boundary).append("--").append("\r\n");
            return body.toString();
        }

        /** Returns the boundary to place in the request Content-Type. */
        public String getBoundary() {
            return boundary;
        }

        /** Generates a random 30–40 character boundary from {@code MULTIPART_CHARS}. */
        private String generateBoundary() {
            StringBuilder sb = new StringBuilder();
            Random rand = new Random();
            int count = rand.nextInt(11) + 30; // a random size from 30 to 40
            for (int i = 0; i < count; i++) {
                sb.append(MULTIPART_CHARS[rand.nextInt(MULTIPART_CHARS.length)]);
            }
            return sb.toString();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy