All downloads are free. Search and download functionality uses the official Maven repository.

graphql.kickstart.autoconfigure.web.servlet.GraphQLWsServerEndpointRegistration Maven / Gradle / Ivy

The newest version!
package graphql.kickstart.autoconfigure.web.servlet;

import graphql.kickstart.servlet.GraphQLWebsocketServlet;
import java.util.ArrayList;
import java.util.List;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.springframework.context.Lifecycle;
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;

/**
 * @author Andrew Potter
 */
/**
 * Spring {@link ServerEndpointRegistration} that registers a {@link GraphQLWebsocketServlet} as a
 * websocket endpoint, enforces an origin allow-list and a CSRF check during the handshake, and
 * maps Spring's {@link Lifecycle#stop()} onto the servlet's shutdown.
 *
 * @author Andrew Potter
 */
public class GraphQLWsServerEndpointRegistration extends ServerEndpointRegistration
    implements Lifecycle {

  /** Allow-list entry that accepts every origin, including a missing {@code Origin} header. */
  private static final String ALL = "*";

  private final GraphQLWebsocketServlet servlet;
  private final WsCsrfFilter csrfFilter;
  /** Unmodifiable, null-free allow-list; {@code ["*"]} when nothing was configured. */
  private final List<String> allowedOrigins;

  /**
   * Creates the registration.
   *
   * @param path the path the websocket endpoint is registered under
   * @param servlet the GraphQL websocket endpoint; also receives handshake and shutdown callbacks
   * @param csrfFilter filter invoked on every handshake request
   * @param allowedOrigins origins allowed to connect; {@code null} or empty allows every origin.
   *     {@code null} entries are ignored.
   */
  public GraphQLWsServerEndpointRegistration(
      String path,
      GraphQLWebsocketServlet servlet,
      WsCsrfFilter csrfFilter,
      List<String> allowedOrigins) {
    super(path, servlet);
    this.servlet = servlet;
    this.csrfFilter = csrfFilter;
    if (allowedOrigins == null || allowedOrigins.isEmpty()) {
      this.allowedOrigins = List.of(ALL);
    } else {
      // Defensive, unmodifiable copy. Null entries are dropped so a misconfigured list cannot
      // make checkOrigin throw a NullPointerException on every handshake.
      this.allowedOrigins =
          allowedOrigins.stream().filter(origin -> origin != null).toList();
    }
  }

  /**
   * Decides whether a handshake carrying the given {@code Origin} header may proceed.
   *
   * <p>If the allow-list contains {@code "*"}, every origin is accepted, including a missing or
   * blank header. Otherwise a missing or blank header is rejected. A present header must equal an
   * allow-list entry, ignoring case and ignoring a single trailing {@code /} on either side.
   *
   * @param originHeaderValue raw value of the {@code Origin} header; may be {@code null}
   * @return {@code true} if the origin is allowed
   */
  @Override
  public boolean checkOrigin(String originHeaderValue) {
    if (allowedOrigins.contains(ALL)) {
      return true;
    }
    if (originHeaderValue == null || originHeaderValue.isBlank()) {
      return false;
    }
    String originToCheck = trimTrailingSlash(originHeaderValue);
    return allowedOrigins.stream()
        .map(GraphQLWsServerEndpointRegistration::trimTrailingSlash)
        .anyMatch(originToCheck::equalsIgnoreCase);
  }

  /** Removes exactly one trailing {@code '/'} from {@code origin}, if present. */
  private static String trimTrailingSlash(String origin) {
    return origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin;
  }

  /**
   * Applies the default handshake handling, runs the CSRF filter, then lets the servlet customize
   * the handshake.
   *
   * <p>NOTE(review): rejection presumably happens by {@link WsCsrfFilter#doFilter} throwing —
   * confirm against that class.
   */
  @Override
  public void modifyHandshake(
      ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
    super.modifyHandshake(sec, request, response);
    csrfFilter.doFilter(request);
    servlet.modifyHandshake(sec, request, response);
  }

  /** No-op: this registration does not start anything itself. */
  @Override
  public void start() {
    // do nothing
  }

  /** Asks the servlet to begin shutting down. */
  @Override
  public void stop() {
    servlet.beginShutDown();
  }

  /** Reports this component as running until the servlet reports that it has shut down. */
  @Override
  public boolean isRunning() {
    return !servlet.isShutDown();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy