All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.github.mike10004.vhs.BasicHeuristic Maven / Gradle / Ivy

package io.github.mike10004.vhs;

import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import com.google.common.io.ByteSource;
import com.google.common.net.HttpHeaders;
import com.google.common.net.MediaType;
import io.github.mike10004.vhs.harbridge.HttpMethod;
import io.github.mike10004.vhs.harbridge.ParsedRequest;
import io.github.mike10004.vhs.repackaged.org.apache.http.NameValuePair;
import io.github.mike10004.vhs.repackaged.org.apache.http.client.utils.URLEncodedUtils;

import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Implementation of a heuristic that compares request headers, parameters, and bodies.
 * This is inspired by the JavaScript implementation of https://github.com/Stuk/server-replay.
 */
public class BasicHeuristic implements Heuristic {

    private static final Charset DEFAULT_FORM_DATA_CHARSET = StandardCharsets.ISO_8859_1;

    public static final int DEFAULT_THRESHOLD_EXCLUSIVE = 0;

    private static final int DEFAULT_INCREMENT = 100;
    private final int increment;
    private final int halfIncrement;

    public BasicHeuristic() {
        this(DEFAULT_INCREMENT);
    }

    public BasicHeuristic(int increment) {
        this.increment = increment;
        this.halfIncrement = this.increment / 2;
    }

    @Override
    public int rate(ParsedRequest entryRequest, ParsedRequest request) {
        // String name;
        URI requestUrl = request.url;
        Multimap requestHeaders = request.indexedHeaders;
        @Nullable Multimap> requestQuery = request.query;
        // method, host and pathname must match
        if (requestUrl == null) {
            return 0;
        }
        if (!entryRequest.method.equals(request.method)) { // TODO replace with enum != enum
            return 0;
        }
        if (!entryRequest.url.getHost().equals(requestUrl.getHost())) {
            return 0;
        }
        if (!entryRequest.url.getPath().equals(requestUrl.getPath())) {
            return 0;
        }
        int points = increment; // One point for matching above requirements

        // each query
        @Nullable Multimap> entryQuery = entryRequest.query;
        if (entryQuery != null && requestQuery != null) {
            for (String name : requestQuery.keySet()) {
                if (!entryQuery.containsKey(name)) {
                    points -= halfIncrement;
                } else {
                    points += stripProtocolFromOptionals(entryQuery.get(name)).equals(stripProtocolFromOptionals(requestQuery.get(name))) ? increment : 0;
                }
            }

            for (String name : entryQuery.keySet()) {
                if (!requestQuery.containsKey(name)) {
                    points -= halfIncrement;
                }
            }
        }

        // each header
        Multimap entryHeaders = entryRequest.indexedHeaders;
        for (String name : requestHeaders.keySet()) {
            if (entryHeaders.containsKey(name)) {
                points += stripProtocolFromStrings(entryHeaders.get(name)).equals(stripProtocolFromStrings(requestHeaders.get(name))) ? increment : 0;
            }
            // TODO handle missing headers and adjust score appropriately
        }

        if (request.method == HttpMethod.POST || request.method == HttpMethod.PUT) {
            if (!request.isBodyPresent() && !entryRequest.isBodyPresent()) {
                points += halfIncrement;
            } else if (request.isBodyPresent() && entryRequest.isBodyPresent()) {
                if (isSameBody(entryRequest, request)) {
                    points += increment;
                }
            }
        }

        return points;
    }

    static boolean isSameBody(ParsedRequest entryRequest, ParsedRequest request) {
        ByteSource requestBody = bodyToByteSource(request);
        ByteSource entryBody = bodyToByteSource(entryRequest);
        @Nullable String entryContentType = entryRequest.getFirstHeaderValue(HttpHeaders.CONTENT_TYPE);
        @Nullable String requestContentType = request.getFirstHeaderValue(HttpHeaders.CONTENT_TYPE);
        return isSameBody(entryBody, entryContentType, requestBody, requestContentType);
    }

    @Nullable
    private static Multiset parseIfWwwFormData(ByteSource body, @Nullable String contentType) {
        if (contentType != null) {
            try {
                MediaType mediaType = MediaType.parse(contentType);
                if (MediaType.FORM_DATA.withoutParameters().equals(mediaType.withoutParameters())) {
                    Charset charset = mediaType.charset().or(DEFAULT_FORM_DATA_CHARSET);
                    String queryString = body.asCharSource(charset).read();
                    List> params = URLEncodedUtils.parse(queryString, charset);
                    if (params == null) {
                        return ImmutableMultiset.of();
                    }
                    return params.stream().map(NameValuePair::from)
                            .collect(ImmutableMultiset.toImmutableMultiset());
                }
            } catch (RuntimeException | IOException ignore) {
            }
        }
        return null;
    }

    static boolean isSameBody(ByteSource entryBody, @Nullable String entryContentType, ByteSource requestBody, @Nullable String requestContentType) {
        if (entryBody.sizeIfKnown().equals(requestBody.sizeIfKnown())) {
            if (entryBody.sizeIfKnown().isPresent() && entryBody.sizeIfKnown().get().equals(0L)) {
                return true;
            }
        }
        @Nullable Multiset entryParams = parseIfWwwFormData(entryBody, entryContentType);
        if (entryParams != null) {
            @Nullable Multiset requestParams = parseIfWwwFormData(requestBody, requestContentType);
            if (requestParams != null) {
                return entryParams.equals(requestParams);
            }
        }
        try {
            if (requestBody.contentEquals(entryBody)) {
                return true;
            }
        } catch (IOException ignore) {
        }
        return false;
    }

    private static ByteSource bodyToByteSource(ParsedRequest request) {
        return new ByteSource() {
            @Override
            public InputStream openStream() throws IOException {
                return request.openBodyStream();
            }
        };
    }

    private static Multiset stripProtocolFromStrings(Collection strings) {
        return strings.stream().map(string -> string.replaceAll("^https?", "")).collect(ImmutableMultiset.toImmutableMultiset());
    }

    private static Multiset> stripProtocolFromOptionals(Collection> strings) {
        return strings.stream()
                .map(stringOpt -> {
                    if (stringOpt.isPresent()) {
                        String string = stringOpt.get();
                        return Optional.of(string.replaceAll("^https?", ""));
                    } else {
                        return stringOpt;
                    }
                }).collect(ImmutableMultiset.toImmutableMultiset());
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy