All Downloads are FREE. Search and download functionalities are using the official Maven repository.

client.lib.rendezvous_sqs.go Maven / Gradle / Ivy

The newest version!
package snowflake_client

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/sqs"
	"github.com/aws/aws-sdk-go-v2/service/sqs/types"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)

// sqsRendezvous holds the state used to exchange a poll request and its
// answer through Amazon SQS queues. Instances are built by newSQSRendezvous.
type sqsRendezvous struct {
	transport   http.RoundTripper   // stored by newSQSRendezvous; not referenced by the methods visible here
	sqsClientID string              // hex-encoded random ID; sent as the "ClientID" attribute and used to name the response queue
	sqsClient   sqsclient.SQSClient // client used for the SendMessage, GetQueueUrl and ReceiveMessage calls
	sqsURL      *url.URL            // URL of the queue that poll requests are sent to
	timeout     time.Duration       // base delay slept between retry attempts
	numRetries  int                 // maximum attempts for locating the response queue and for receiving the answer
}

// sqsRegionPattern matches the hostname of a regional SQS endpoint and
// captures the AWS region, e.g. "us-east-1" in "sqs.us-east-1.amazonaws.com".
var sqsRegionPattern = regexp.MustCompile(`^sqs\.([\w-]+)\.amazonaws\.com$`)

// newSQSRendezvous creates an sqsRendezvous that sends poll requests to the
// SQS queue at sqsQueue, authenticating with the given static credentials.
//
// The AWS region is derived from the hostname of sqsQueue, which must have
// the form sqs.<region>.amazonaws.com. A fresh random client ID is generated
// for every instance. The returned value uses a one second base timeout and
// five retries.
//
// An error is returned when sqsQueue cannot be parsed, when no region can be
// extracted from its hostname, when the random client ID cannot be generated,
// or when the AWS configuration cannot be loaded. Failures are reported to the
// caller instead of terminating the process.
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
	sqsURL, err := url.Parse(sqsQueue)
	if err != nil {
		return nil, err
	}

	var id [8]byte
	if _, err := rand.Read(id[:]); err != nil {
		return nil, fmt.Errorf("generating SQS client ID: %w", err)
	}
	clientID := hex.EncodeToString(id[:])

	queueURL := sqsURL.String()

	// The submatch slice holds the full match followed by the captured region,
	// so anything shorter than two elements means the hostname did not match.
	match := sqsRegionPattern.FindStringSubmatch(sqsURL.Hostname())
	if len(match) < 2 {
		return nil, fmt.Errorf("could not extract AWS region from SQS URL %q; ensure that the SQS queue URL provided is valid", sqsQueue)
	}
	region := match[1]

	cfg, err := config.LoadDefaultConfig(context.TODO(),
		config.WithCredentialsProvider(
			credentials.NewStaticCredentialsProvider(sqsAccessKeyId, sqsSecretKey, ""),
		),
		config.WithRegion(region),
	)
	if err != nil {
		return nil, fmt.Errorf("loading AWS configuration: %w", err)
	}
	client := sqs.NewFromConfig(cfg)

	log.Println("Queue URL: ", queueURL)
	log.Println("SQS Client ID: ", clientID)

	return &sqsRendezvous{
		transport:   transport,
		sqsClientID: clientID,
		sqsClient:   client,
		sqsURL:      sqsURL,
		timeout:     time.Second,
		numRetries:  5,
	}, nil
}

// Exchange sends encPollReq to the rendezvous SQS queue and waits for the
// answer on this client's response queue, named "snowflake-client-" followed
// by the client ID.
//
// The request is sent with the client ID attached as the "ClientID" message
// attribute. The response queue URL is then looked up, and the queue is polled
// for a single message; each step is attempted at most r.numRetries times.
// The body of the first message received is returned.
//
// An error is returned when the request cannot be sent, when the response
// queue cannot be located, when receiving fails, when the received message has
// no body, or when no message arrives within the allowed number of attempts.
func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
	log.Println("Negotiating via SQS Queue rendezvous...")

	_, err := r.sqsClient.SendMessage(context.TODO(), &sqs.SendMessageInput{
		MessageAttributes: map[string]types.MessageAttributeValue{
			"ClientID": {
				DataType:    aws.String("String"),
				StringValue: aws.String(r.sqsClientID),
			},
		},
		MessageBody: aws.String(string(encPollReq)),
		QueueUrl:    aws.String(r.sqsURL.String()),
	})
	if err != nil {
		return nil, err
	}

	time.Sleep(r.timeout) // wait for client queue to be created by the broker

	var responseQueueURL *string
	for i := 0; i < r.numRetries; i++ {
		// The SQS queue corresponding to the client where the SDP Answer will be placed
		// may not be created yet. We retry up to r.numRetries times before we error out.
		var res *sqs.GetQueueUrlOutput
		res, err = r.sqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{
			QueueName: aws.String("snowflake-client-" + r.sqsClientID),
		})
		if err == nil {
			responseQueueURL = res.QueueUrl
			break
		}
		log.Println(err)
		log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries)
		// Sleeping after the final attempt would only delay the failure.
		if i < r.numRetries-1 {
			time.Sleep(r.timeout)
		}
	}
	if err != nil {
		return nil, err
	}
	// Covers a non-positive retry count and a response that carries no URL;
	// polling with a nil queue URL cannot succeed.
	if responseQueueURL == nil {
		return nil, fmt.Errorf("URL of response SQS queue for client %s is unavailable", r.sqsClientID)
	}

	for i := 0; i < r.numRetries; i++ {
		// Waiting for SDP Answer from proxy to be placed in SQS queue.
		// We retry up to r.numRetries times before we error out.
		res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
			QueueUrl:            responseQueueURL,
			MaxNumberOfMessages: 1,
			WaitTimeSeconds:     20,
		})
		if err != nil {
			return nil, err
		}
		if len(res.Messages) > 0 {
			body := res.Messages[0].Body
			if body == nil {
				return nil, fmt.Errorf("message received from response SQS queue has no body")
			}
			return []byte(*body), nil
		}
		log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries)
		// Sleeping after the final attempt would only delay the failure.
		if i < r.numRetries-1 {
			// The delay grows by half of r.timeout per attempt: 1x, 1.5x, 2x, ...
			delay := float64(i)/2.0 + 1
			time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000))
		}
	}

	// Every attempt came back empty; report it instead of returning an empty answer.
	return nil, fmt.Errorf("no answer received from response SQS queue after %d attempts", r.numRetries)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy