All Downloads are FREE. Search and download functionalities are using the official Maven repository.

broker.sqs_test.go Maven / Gradle / Ivy

The newest version!
package main

import (
	"bytes"
	"context"
	"errors"
	"log"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"github.com/golang/mock/gomock"
	. "github.com/smartystreets/goconvey/convey"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)

// TestSQS drives the broker's SQS handler against a gomock SQSClient.
//
// It covers two behaviours:
//   - answering client offers that arrive on the broker's SQS queue, and
//   - the periodic cleanup of per-client "snowflake-client-" queues.
//
// goconvey re-executes the enclosing Convey bodies once per leaf, so every
// leaf below gets a fresh broker context, log capture, mock controller and
// handler goroutine. Each leaf defers the cancellation of its handler context
// and waits on its WaitGroup explicitly at the end of the leaf rather than in
// a defer: a halting So() failure unwinds the leaf with a panic, and a
// deferred Wait would then block forever instead of letting goconvey report
// the failure. Asynchronous assertions run before wg.Done() so they complete
// while their leaf is still active.
func TestSQS(t *testing.T) {
	Convey("Context", t, func() {
		buf := new(bytes.Buffer)
		ipcCtx := NewBrokerContext(log.New(buf, "", 0))
		i := &IPC{ipcCtx}

		// The handler's diagnostics are asserted on via the standard logger;
		// capture them here and restore the previous writer when this pass
		// finishes so the redirection does not outlive the test.
		var logBuffer bytes.Buffer
		prevLogOutput := log.Writer()
		log.SetOutput(&logBuffer)
		defer log.SetOutput(prevLogOutput)

		Convey("Responds to SQS client offers...", func() {
			ctrl := gomock.NewController(t)
			mockSQSClient := sqsclient.NewMockSQSClient(ctrl)

			brokerSQSQueueName := "example-name"
			responseQueueURL := aws.String("https://sqs.us-east-1.amazonaws.com/testing")

			// runSQSHandler expects exactly one CreateQueue call for the broker
			// queue (5 minute message retention), builds the handler and starts
			// PollAndHandleMessages in a goroutine bound to sqsHandlerContext.
			runSQSHandler := func(sqsHandlerContext context.Context) {
				mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqs.CreateQueueInput{
					QueueName: aws.String(brokerSQSQueueName),
					Attributes: map[string]string{
						"MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
					},
				}).Return(&sqs.CreateQueueOutput{
					QueueUrl: responseQueueURL,
				}, nil).Times(1)
				sqsHandler, err := newSQSHandler(sqsHandlerContext, mockSQSClient, brokerSQSQueueName, "example-region", i)
				So(err, ShouldBeNil)
				go sqsHandler.PollAndHandleMessages(sqsHandlerContext)
			}

			messageBody := aws.String("1.0\n{\"offer\": \"fake\", \"nat\": \"unknown\"}")
			receiptHandle := "fake-receipt-handle"
			sqsReceiveMessageInput := sqs.ReceiveMessageInput{
				QueueUrl:            responseQueueURL,
				MaxNumberOfMessages: 10,
				WaitTimeSeconds:     15,
				MessageAttributeNames: []string{
					string(types.QueueAttributeNameAll),
				},
			}
			sqsDeleteMessageInput := sqs.DeleteMessageInput{
				QueueUrl:      responseQueueURL,
				ReceiptHandle: &receiptHandle,
			}

			Convey("by ignoring it if no client id specified", func(c C) {
				var wg sync.WaitGroup
				wg.Add(1)

				sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
				defer sqsCancelFunc()

				// Every poll returns the same message, which has no ClientID attribute.
				mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn(
					func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
						return &sqs.ReceiveMessageOutput{
							Messages: []types.Message{
								{
									Body:          messageBody,
									ReceiptHandle: &receiptHandle,
								},
							},
						}, nil
					},
				)
				// By the first deletion the handler must have logged that it ignored
				// the message; later deletions of the repeated message are allowed.
				mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do(
					func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
						defer wg.Done()
						c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
						mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
					},
				)
				runSQSHandler(sqsHandlerContext)
				wg.Wait()
			})

			Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) {
				var wg sync.WaitGroup
				wg.Add(2)

				sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
				defer sqsCancelFunc()

				numTimes := 0
				// When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet.
				// Therefore, we opt to wait for the second call because we are guaranteed that the error was logged
				// by then.
				mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn(
					func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
						numTimes++
						if numTimes <= 2 {
							if numTimes == 2 {
								c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error")
							}
							wg.Done()
						}
						return nil, errors.New("error")
					},
				)
				runSQSHandler(sqsHandlerContext)
				wg.Wait()
			})

			Convey("by attempting to create a new sqs queue...", func() {
				clientID := "fake-id"
				sqsCreateQueueInput := sqs.CreateQueueInput{
					QueueName: aws.String("snowflake-client-fake-id"),
				}

				// expectReceiveMessageReturnsValidMessage makes every poll return the
				// offer message tagged with ClientID "fake-id".
				expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) {
					mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
						func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
							return &sqs.ReceiveMessageOutput{
								Messages: []types.Message{
									{
										Body: messageBody,
										MessageAttributes: map[string]types.MessageAttributeValue{
											"ClientID": {StringValue: &clientID},
										},
										ReceiptHandle: &receiptHandle,
									},
								},
							}, nil
						},
					)
				}

				Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) {
					var wg sync.WaitGroup
					wg.Add(2)

					sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
					defer sqsCancelFunc()

					expectReceiveMessageReturnsValidMessage(sqsHandlerContext)
					// No SendMessage expectation is registered, so any send attempt
					// fails the mock.
					mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes()
					numTimes := 0
					// Check the log on the second deletion, mirroring the
					// receive-error case above.
					mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do(
						func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
							numTimes++
							if numTimes <= 2 {
								if numTimes == 2 {
									c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error")
								}
								wg.Done()
							}
						},
					)
					runSQSHandler(sqsHandlerContext)
					wg.Wait()
				})

				Convey("and responds with a proxy answer if available.", func(c C) {
					var wg sync.WaitGroup
					wg.Add(1)

					sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
					defer sqsCancelFunc()

					expectReceiveMessageReturnsValidMessage(sqsHandlerContext)
					mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
						QueueUrl: responseQueueURL,
					}, nil).AnyTimes()
					mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
					numTimes := 0
					mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn(
						func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
							numTimes++
							if numTimes == 1 {
								// Compare the body's contents rather than two *string pointers.
								c.So(aws.ToString(input.MessageBody), ShouldEqual, "{\"answer\":\"fake answer\"}")
								// Ensure that match is correctly recorded in metrics.
								// NOTE(review): counts show up as 8, presumably because the
								// metrics output rounds counts up into bins — confirm.
								ipcCtx.metrics.printMetrics()
								c.So(buf.String(), ShouldContainSubstring, `client-denied-count 0
client-restricted-denied-count 0
client-unrestricted-denied-count 0
client-snowflake-match-count 8
client-http-count 0
client-ampcache-count 0
client-sqs-count 8
`)
								wg.Done()
							}
							return &sqs.SendMessageOutput{}, nil
						},
					)
					runSQSHandler(sqsHandlerContext)

					// Act as the proxy: take the client's offer and hand back an answer.
					snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0)

					offer := <-snowflake.offerChannel
					So(offer.sdp, ShouldResemble, []byte("fake"))

					snowflake.answerChannel <- "fake answer"
					wg.Wait()
				})
			})
		})

		Convey("Cleans up SQS client queues...", func() {
			brokerSQSQueueName := "example-name"
			responseQueueURL := aws.String("https://sqs.us-east-1.amazonaws.com/testing")

			ctrl := gomock.NewController(t)
			mockSQSClient := sqsclient.NewMockSQSClient(ctrl)

			// runSQSHandler expects the broker queue to be created once, makes every
			// poll return no messages, and starts the handler with a 1ns cleanup
			// interval so the client-queue cleanup logic runs right away.
			runSQSHandler := func(sqsHandlerContext context.Context) {
				mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqs.CreateQueueInput{
					QueueName: aws.String(brokerSQSQueueName),
					Attributes: map[string]string{
						"MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
					},
				}).Return(&sqs.CreateQueueOutput{
					QueueUrl: responseQueueURL,
				}, nil).Times(1)

				mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, gomock.Any()).AnyTimes().Return(
					&sqs.ReceiveMessageOutput{
						Messages: []types.Message{},
					}, nil,
				)

				sqsHandler, err := newSQSHandler(sqsHandlerContext, mockSQSClient, brokerSQSQueueName, "example-region", i)
				So(err, ShouldBeNil)
				// Set the cleanup interval to 1 ns so we can immediately test the cleanup logic
				sqsHandler.cleanupInterval = time.Nanosecond

				go sqsHandler.PollAndHandleMessages(sqsHandlerContext)
			}

			Convey("does nothing if there are no open queues.", func() {
				var wg sync.WaitGroup
				wg.Add(1)
				sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
				defer sqsCancelFunc()

				mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
					QueueNamePrefix: aws.String("snowflake-client-"),
					MaxResults:      aws.Int32(1000),
					NextToken:       nil,
				}).DoAndReturn(func(ctx context.Context, input *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) {
					// Cancel the handler context since we are only interested in testing one iteration of the cleanup
					sqsCancelFunc()
					wg.Done()
					return &sqs.ListQueuesOutput{
						QueueUrls: []string{},
					}, nil
				})

				runSQSHandler(sqsHandlerContext)
				wg.Wait()
			})

			Convey("deletes open queue when there is one open queue.", func(c C) {
				var wg sync.WaitGroup
				wg.Add(1)
				sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
				defer sqsCancelFunc()

				clientQueueURL1 := "https://sqs.us-east-1.amazonaws.com/snowflake-client-1"
				clientQueueURL2 := "https://sqs.us-east-1.amazonaws.com/snowflake-client-2"

				gomock.InOrder(
					mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
						QueueNamePrefix: aws.String("snowflake-client-"),
						MaxResults:      aws.Int32(1000),
						NextToken:       nil,
					}).Times(1).Return(&sqs.ListQueuesOutput{
						QueueUrls: []string{
							clientQueueURL1,
							clientQueueURL2,
						},
					}, nil),
					mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
						QueueNamePrefix: aws.String("snowflake-client-"),
						MaxResults:      aws.Int32(1000),
						NextToken:       nil,
					}).Times(1).DoAndReturn(func(ctx context.Context, input *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) {
						// Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration
						c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.")
						sqsCancelFunc()
						wg.Done()
						return &sqs.ListQueuesOutput{
							QueueUrls: []string{},
						}, nil
					}),
				)

				// Both queues report a LastModifiedTimestamp of "0".
				gomock.InOrder(
					mockSQSClient.EXPECT().GetQueueAttributes(sqsHandlerContext, &sqs.GetQueueAttributesInput{
						QueueUrl:       aws.String(clientQueueURL1),
						AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
					}).Times(1).Return(&sqs.GetQueueAttributesOutput{
						Attributes: map[string]string{
							string(types.QueueAttributeNameLastModifiedTimestamp): "0",
						}}, nil),

					mockSQSClient.EXPECT().GetQueueAttributes(sqsHandlerContext, &sqs.GetQueueAttributesInput{
						QueueUrl:       aws.String(clientQueueURL2),
						AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
					}).Times(1).Return(&sqs.GetQueueAttributesOutput{
						Attributes: map[string]string{
							string(types.QueueAttributeNameLastModifiedTimestamp): "0",
						}}, nil),
				)

				gomock.InOrder(
					mockSQSClient.EXPECT().DeleteQueue(sqsHandlerContext, &sqs.DeleteQueueInput{
						QueueUrl: aws.String(clientQueueURL1),
					}).Return(&sqs.DeleteQueueOutput{}, nil),
					mockSQSClient.EXPECT().DeleteQueue(sqsHandlerContext, &sqs.DeleteQueueInput{
						QueueUrl: aws.String(clientQueueURL2),
					}).Return(&sqs.DeleteQueueOutput{}, nil),
				)

				runSQSHandler(sqsHandlerContext)
				wg.Wait()
			})
		})
	})
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy