All Downloads are FREE. Search and download functionalities are using the official Maven repository.

broker.sqs.go Maven / Gradle / Ivy

The newest version!
package main

import (
	"context"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)

// cleanupThreshold is added to the current time to compute the cutoff used by
// cleanupClientQueues: a client queue whose LastModifiedTimestamp is before
// now+cleanupThreshold (i.e. more than two minutes ago) is deleted.
const cleanupThreshold = -(2 * time.Minute)

// sqsHandler receives client offers from the broker's SQS queue, forwards
// them to the broker over IPC, and writes answers to per-client
// "snowflake-client-<ID>" queues. It also periodically deletes stale client
// queues.
type sqsHandler struct {
	// SQSClient is used for every SQS API call made by the handler.
	SQSClient       sqsclient.SQSClient
	// SQSQueueURL is the URL of the broker's incoming queue; it is polled
	// for messages and used when deleting processed messages.
	SQSQueueURL     *string
	// IPC is used to hand each client offer to the broker (ClientOffers).
	IPC             *IPC
	// cleanupInterval is the period between client-queue cleanup passes
	// (newSQSHandler sets it to 30 seconds).
	cleanupInterval time.Duration
}

// pollMessages long-polls the broker's SQS queue (r.SQSQueueURL) and forwards
// every received message to chn. Each ReceiveMessage call asks for up to 10
// messages and waits up to 15 seconds for them. It returns once ctx is
// cancelled, including while blocked sending on chn. It does not close chn.
//
// Receive errors are logged and retried after a short pause, so a persistent
// failure does not turn into a tight, log-flooding loop.
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
	const retryDelay = time.Second

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		default:
		}

		res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
			QueueUrl:            r.SQSQueueURL,
			MaxNumberOfMessages: 10,
			WaitTimeSeconds:     15,
			MessageAttributeNames: []string{
				string(types.QueueAttributeNameAll),
			},
		})
		if err != nil {
			if ctx.Err() != nil {
				// The failure is presumably caused by cancellation; stop quietly.
				return
			}
			log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
			select {
			case <-ctx.Done():
				return
			case <-time.After(retryDelay):
			}
			continue
		}

		for i := range res.Messages {
			// Send the address of the slice element rather than of a range
			// variable: before Go 1.22 the range variable is a single shared
			// variable, so &message would alias it across every send.
			select {
			case chn <- &res.Messages[i]:
			case <-ctx.Done():
				// Don't block forever on a consumer that has stopped receiving.
				return
			}
		}
	}
}

// cleanupClientQueues runs a cleanup pass every r.cleanupInterval until ctx is
// cancelled. Each pass deletes the client answer queues ("snowflake-client-*")
// whose LastModifiedTimestamp is older than now+cleanupThreshold, i.e. queues
// last changed more than two minutes ago.
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
	ticker := time.NewTicker(r.cleanupInterval)
	// Release the ticker when the goroutine exits.
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case <-ticker.C:
			// If cancellation raced with the tick, prefer stopping.
			if ctx.Err() != nil {
				return
			}
			numDeleted := 0
			cleanupCutoff := time.Now().Add(cleanupThreshold)
			for _, queueURL := range r.listClientQueueURLs(ctx) {
				if !strings.Contains(queueURL, "snowflake-client-") {
					continue
				}
				if r.deleteClientQueueIfStale(ctx, queueURL, cleanupCutoff) {
					numDeleted++
				}
			}
			log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
		}
	}
}

// listClientQueueURLs returns the URLs of all queues whose name starts with
// "snowflake-client-", following ListQueues pagination. If a page request
// fails, the error is logged and the URLs collected so far are returned; the
// remaining queues will be picked up by a later cleanup pass.
func (r *sqsHandler) listClientQueueURLs(ctx context.Context) []string {
	var queueURLs []string
	var nextToken *string
	for {
		res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
			QueueNamePrefix: aws.String("snowflake-client-"),
			MaxResults:      aws.Int32(1000),
			NextToken:       nextToken,
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
			// client queues will be cleaned up the next time the cleanup operation is triggered automatically
			return queueURLs
		}
		queueURLs = append(queueURLs, res.QueueUrls...)
		if res.NextToken == nil {
			return queueURLs
		}
		nextToken = res.NextToken
	}
}

// deleteClientQueueIfStale deletes the queue at queueURL if its
// LastModifiedTimestamp is before cutoff. It reports whether the queue was
// deleted. Errors are logged and treated as "not deleted".
func (r *sqsHandler) deleteClientQueueIfStale(ctx context.Context, queueURL string, cutoff time.Time) bool {
	res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
		QueueUrl:       aws.String(queueURL),
		AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
	})
	if err != nil {
		// According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
		// can be in the process of being deleted, but will still be returned by the ListQueues operation, but
		// fail when we try to GetQueueAttributes for the queue
		log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted: %v\n", queueURL, err)
		return false
	}

	// LastModifiedTimestamp is parsed as whole seconds since the Unix epoch.
	lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
	if err != nil {
		log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
		return false
	}
	if !time.Unix(lastModifiedInt64, 0).Before(cutoff) {
		return false
	}

	if _, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
		QueueUrl: aws.String(queueURL),
	}); err != nil {
		log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
		return false
	}
	return true
}

// handleMessage processes one client offer received from the broker queue.
//
// The message must carry a "ClientID" string attribute. An answer queue named
// "snowflake-client-<ClientID>" is created (or reused), the message body is
// passed to the broker via r.IPC.ClientOffers, and the broker's response is
// sent to that answer queue. Any failure is logged and the message is dropped.
// Deleting the message from the broker queue is left to the caller.
func (r *sqsHandler) handleMessage(ctx context.Context, message *types.Message) {
	clientID := message.MessageAttributes["ClientID"].StringValue
	if clientID == nil {
		log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
		return
	}
	if message.Body == nil {
		// Guard against a nil dereference; there is nothing to forward.
		log.Printf("SQSHandler: got SQS message with no body from client %s. ignoring this message.\n", *clientID)
		return
	}

	res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName: aws.String("snowflake-client-" + *clientID),
	})
	if err != nil {
		log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
		return
	}
	answerSQSURL := res.QueueUrl

	arg := messages.Arg{
		Body:             []byte(*message.Body),
		RemoteAddr:       "",
		RendezvousMethod: messages.RendezvousSqs,
	}
	var response []byte
	if err := r.IPC.ClientOffers(arg, &response); err != nil {
		log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
		return
	}

	if _, err := r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
		QueueUrl:    answerSQSURL,
		MessageBody: aws.String(string(response)),
	}); err != nil {
		log.Printf("SQSHandler: error encountered when sending answer to client %s: %v\n", *clientID, err)
	}
}

// deleteMessage removes message from the broker queue (r.SQSQueueURL) using
// its receipt handle. Failures are logged rather than silently ignored.
func (r *sqsHandler) deleteMessage(ctx context.Context, message *types.Message) {
	if _, err := r.SQSClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
		QueueUrl:      r.SQSQueueURL,
		ReceiptHandle: message.ReceiptHandle,
	}); err != nil {
		log.Printf("SQSHandler: encountered error when deleting message from broker queue: %v\n", err)
	}
}

// newSQSHandler makes sure the broker queue named sqsQueueName exists and
// returns a handler bound to it, with a 30-second client-queue cleanup interval.
//
// CreateQueue creates the queue when no queue with that name exists. If one
// exists with the same attributes, nothing happens. If one exists with
// different attributes, the CreateQueue error is returned.
// The queue's MessageRetentionPeriod is set to five minutes (in seconds).
// The region argument is currently unused.
func newSQSHandler(ctx context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
	const messageRetention = 5 * time.Minute
	retentionSeconds := strconv.FormatInt(int64(messageRetention/time.Second), 10)

	created, err := client.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName:  aws.String(sqsQueueName),
		Attributes: map[string]string{"MessageRetentionPeriod": retentionSeconds},
	})
	if err != nil {
		return nil, err
	}

	handler := &sqsHandler{
		SQSClient:       client,
		SQSQueueURL:     created.QueueUrl,
		IPC:             i,
		cleanupInterval: 30 * time.Second,
	}
	return handler, nil
}

// PollAndHandleMessages starts the background goroutines that poll the broker
// queue and clean up stale client queues. It then handles and deletes each
// received message. It blocks until ctx is cancelled (or the message channel
// is closed), and returns promptly on cancellation even if no message arrives.
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
	log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
	messagesChn := make(chan *types.Message, 2)
	go r.pollMessages(ctx, messagesChn)
	go r.cleanupClientQueues(ctx)

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case message, ok := <-messagesChn:
			if !ok {
				return
			}
			// If cancellation raced with the receive, don't process the message.
			if ctx.Err() != nil {
				return
			}
			r.handleMessage(ctx, message)
			r.deleteMessage(ctx, message)
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy