All Downloads are FREE. Search and download functionalities are using the official Maven repository.

it.auties.whatsapp.socket.AppStateHandler Maven / Gradle / Ivy

There is a newer version: 2.7.2
Show newest version
package it.auties.whatsapp.socket;

import it.auties.bytes.Bytes;
import it.auties.whatsapp.util.Protobuf;
import it.auties.whatsapp.binary.PatchType;
import it.auties.whatsapp.crypto.AesCbc;
import it.auties.whatsapp.crypto.Hmac;
import it.auties.whatsapp.crypto.LTHash;
import it.auties.whatsapp.exception.HmacValidationException;
import it.auties.whatsapp.model.action.*;
import it.auties.whatsapp.model.chat.Chat;
import it.auties.whatsapp.model.chat.ChatMute;
import it.auties.whatsapp.model.contact.Contact;
import it.auties.whatsapp.model.info.MessageIndexInfo;
import it.auties.whatsapp.model.info.MessageInfo;
import it.auties.whatsapp.model.request.Node;
import it.auties.whatsapp.model.setting.EphemeralSetting;
import it.auties.whatsapp.model.setting.LocaleSetting;
import it.auties.whatsapp.model.setting.PushNameSetting;
import it.auties.whatsapp.model.setting.UnarchiveChatsSetting;
import it.auties.whatsapp.model.sync.*;
import it.auties.whatsapp.util.*;
import lombok.NonNull;

import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static it.auties.whatsapp.api.ErrorHandler.Location.*;
import static java.lang.System.Logger.Level.WARNING;

class AppStateHandler {
    public static final int TIMEOUT = 120;
    private static final int PULL_ATTEMPTS = 3;
    private final SocketHandler socketHandler;
    private final Map attempts;
    private final OrderedAsyncTaskRunner runner;

    protected AppStateHandler(SocketHandler socketHandler) {
        this.socketHandler = socketHandler;
        this.attempts = new HashMap<>();
        this.runner = new OrderedAsyncTaskRunner();
    }

    protected CompletableFuture push(@NonNull PatchRequest patch) {
        return runner.runAsync(() -> pullUninterruptedly(List.of(patch.type()))
                .thenCompose(ignored -> sendPush(createPushRequest(patch)))
                .exceptionallyAsync(throwable -> socketHandler.handleFailure(PUSH_APP_STATE, throwable))
                .orTimeout(TIMEOUT, TimeUnit.SECONDS));
    }

    private PushRequest createPushRequest(PatchRequest patch) {
        try {
            var oldState = socketHandler.keys()
                    .findHashStateByName(patch.type())
                    .orElseGet(() -> new LTHashState(patch.type()));
            var newState = oldState.copy();
            var key = socketHandler.keys().appKey();
            var index = patch.index().getBytes(StandardCharsets.UTF_8);
            var actionData = ActionDataSync.builder()
                    .index(index)
                    .value(patch.sync())
                    .padding(new byte[0])
                    .version(patch.version())
                    .build();
            var encoded = Protobuf.writeMessage(actionData);
            var mutationKeys = MutationKeys.of(key.keyData().keyData());
            var encrypted = AesCbc.encryptAndPrefix(encoded, mutationKeys.encKey());
            var valueMac = generateMac(patch.operation(), encrypted, key.keyId().keyId(), mutationKeys.macKey());
            var indexMac = Hmac.calculateSha256(index, mutationKeys.indexKey());
            var generator = new LTHash(newState);
            generator.mix(indexMac, valueMac, patch.operation());
            var result = generator.finish();
            newState.hash(result.hash());
            newState.indexValueMap(result.indexValueMap());
            newState.version(newState.version() + 1);
            var syncId = new KeyId(key.keyId().keyId());
            var record = RecordSync.builder()
                    .index(new IndexSync(indexMac))
                    .value(new ValueSync(Bytes.of(encrypted, valueMac).toByteArray()))
                    .keyId(syncId)
                    .build();
            var mutation = MutationSync.builder().operation(patch.operation()).record(record).build();
            var snapshotMac = generateSnapshotMac(newState.hash(), newState.version(), patch.type(), mutationKeys.snapshotMacKey());
            var patchMac = generatePatchMac(snapshotMac, valueMac, newState.version(), patch.type(), mutationKeys.patchMacKey());
            var sync = PatchSync.builder()
                    .patchMac(patchMac)
                    .snapshotMac(snapshotMac)
                    .keyId(syncId)
                    .mutations(List.of(mutation))
                    .build();
            newState.indexValueMap().put(Bytes.of(indexMac).toBase64(), valueMac);
            return new PushRequest(patch, oldState, newState, sync);
        } catch (Throwable throwable) {
            throw new RuntimeException("Cannot create patch %s".formatted(patch), throwable);
        }
    }

    private CompletableFuture sendPush(PushRequest request) {
        var body = Node.ofChildren("collection", Map.of("name", request.patch().type(), "version", request.newState().version() - 1, "return_snapshot", false),
                Node.of("patch", Protobuf.writeMessage(request.sync())));
        return socketHandler.sendQuery("set", "w:sync:app:state", Node.ofChildren("sync", body))
                .thenAcceptAsync(this::parseSyncRequest)
                .thenRunAsync(() -> socketHandler.keys().putState(request.patch().type(), request.newState()))
                .thenRunAsync(() -> handleSyncRequest(request.patch()
                        .type(), request.sync(), request.oldState(), request.newState().version()));
    }

    private void handleSyncRequest(PatchType patchType, PatchSync patch, LTHashState oldState, long newVersion) {
        var patches = List.of(patch.withVersion(new VersionSync(newVersion)));
        var results = decodePatches(patchType, patches, oldState);
        results.records().forEach(this::processActions);
    }

    protected CompletableFuture pull(PatchType... patchTypes) {
        if (patchTypes == null || patchTypes.length == 0) {
            return CompletableFuture.completedFuture(null);
        }

        return runner.runAsync(() -> pullUninterruptedly(Arrays.asList(patchTypes))
                .thenAcceptAsync(success -> onPull(false, success))
                .exceptionallyAsync(exception -> onPullError(false, exception)));
    }

    protected CompletableFuture pullInitial() {
        if(socketHandler.store().initialSync()){
            return CompletableFuture.completedFuture(null);
        }

        return pullUninterruptedly(Arrays.asList(PatchType.values()))
                .thenAcceptAsync(success -> onPull(true, success))
                .exceptionallyAsync(exception -> onPullError(true, exception));
    }

    private void onPull(boolean initial, boolean success) {
        if (!socketHandler.store().initialSync()) {
            socketHandler.store().initialSync((initial && success) || isSyncComplete());
        }

        attempts.clear();
    }

    private boolean isSyncComplete() {
        return Arrays.stream(PatchType.values())
                .allMatch(this::isSyncComplete);
    }

    private boolean isSyncComplete(PatchType entry) {
        return socketHandler.keys()
                .findHashStateByName(entry)
                .filter(type -> type.version() > 0)
                .isPresent();
    }

    private Void onPullError(boolean initial, Throwable exception) {
        attempts.clear();
        if (initial) {
            return socketHandler.handleFailure(INITIAL_APP_STATE_SYNC, exception);
        }
        return socketHandler.handleFailure(PULL_APP_STATE, exception);
    }

    private CompletableFuture pullUninterruptedly(List patchTypes) {
        var tempStates = new HashMap();
        var nodes = getPullNodes(patchTypes, tempStates);
        return socketHandler.sendQuery("set", "w:sync:app:state", Node.ofChildren("sync", nodes))
                .thenApplyAsync(this::parseSyncRequest)
                .thenApplyAsync(records -> decodeSyncs(tempStates, records))
                .thenComposeAsync(this::handlePullResult)
                .orTimeout(TIMEOUT, TimeUnit.SECONDS);
    }

    private CompletableFuture handlePullResult(List remaining) {
        return remaining.isEmpty() ? CompletableFuture.completedFuture(true) : pullUninterruptedly(remaining);
    }

    private List getPullNodes(List patchTypes, Map tempStates) {
        return patchTypes.stream()
                .map(this::createStateWithVersion)
                .peek(state -> tempStates.put(state.name(), state))
                .map(LTHashState::toNode)
                .toList();
    }

    private LTHashState createStateWithVersion(PatchType name) {
        return socketHandler.keys().findHashStateByName(name).orElseGet(() -> new LTHashState(name));
    }

    private List decodeSyncs(Map tempStates, List records) {
        return records.stream()
                .map(record -> decodeSync(record, tempStates))
                .peek(chunk -> chunk.records().forEach(this::processActions))
                .filter(PatchChunk::hasMore)
                .map(PatchChunk::patchType)
                .toList();
    }

    private PatchChunk decodeSync(SnapshotSyncRecord record, Map tempStates) {
        try {
            var results = new ArrayList();
            if (record.hasSnapshot()) {
                var snapshot = decodeSnapshot(record.patchType(), record.snapshot());
                snapshot.ifPresent(decodedSnapshot -> {
                    results.addAll(decodedSnapshot.records());
                    tempStates.put(record.patchType(), decodedSnapshot.state());
                    socketHandler.keys().putState(record.patchType(), decodedSnapshot.state());
                });
            }
            if (record.hasPatches()) {
                var decodedPatches = decodePatches(record.patchType(), record.patches(), tempStates.get(record.patchType()));
                results.addAll(decodedPatches.records());
                socketHandler.keys().putState(record.patchType(), decodedPatches.state());
            }
            return new PatchChunk(record.patchType(), results, record.hasMore());
        } catch (Throwable throwable) {
            var hashState = new LTHashState(record.patchType());
            socketHandler.keys().putState(record.patchType(), hashState);
            attempts.put(record.patchType(), attempts.getOrDefault(record.patchType(), 0) + 1);
            if (attempts.get(record.patchType()) >= PULL_ATTEMPTS) {
                throw new RuntimeException("Cannot parse patch(%s tries)".formatted(PULL_ATTEMPTS), throwable);
            }
            return decodeSync(record, tempStates);
        }
    }

    private List parseSyncRequest(Node node) {
        return Stream.ofNullable(node)
                .map(sync -> sync.findNodes("sync"))
                .flatMap(Collection::stream)
                .map(sync -> sync.findNodes("collection"))
                .flatMap(Collection::stream)
                .map(this::parseSync)
                .flatMap(Optional::stream)
                .toList();
    }

    private Optional parseSync(Node sync) {
        var name = PatchType.of(sync.attributes().getString("name"));
        var type = sync.attributes().getString("type");
        if (Objects.equals(type, "error")) {
            return Optional.empty();
        }
        var more = sync.attributes().getBoolean("has_more_patches");
        var snapshotSync = sync.findNode("snapshot").flatMap(this::decodeSnapshot).orElse(null);
        var versionCode = sync.attributes().getInt("version");
        var patches = sync.findNode("patches")
                .orElse(sync)
                .findNodes("patch")
                .stream()
                .map(patch -> decodePatch(patch, versionCode))
                .flatMap(Optional::stream)
                .toList();
        return Optional.of(new SnapshotSyncRecord(name, snapshotSync, patches, more));
    }

    private Optional decodeSnapshot(Node snapshot) {
        return snapshot == null ? Optional.empty() : snapshot.contentAsBytes()
                .map(bytes -> Protobuf.readMessage(bytes, ExternalBlobReference.class))
                .map(Medias::download)
                .flatMap(CompletableFuture::join)
                .map(value -> Protobuf.readMessage(value, SnapshotSync.class));
    }

    private Optional decodePatch(Node patch, long versionCode) {
        if (!patch.hasContent()) {
            return Optional.empty();
        }
        var patchSync = Protobuf.readMessage(patch.contentAsBytes().orElseThrow(), PatchSync.class);
        if (!patchSync.hasVersion()) {
            var version = new VersionSync(versionCode + 1);
            patchSync.version(version);
        }
        return Optional.of(patchSync);
    }

    private void processActions(ActionDataSync mutation) {
        var value = mutation.value();
        if (value == null) {
            return;
        }
        var action = value.action();
        if (action != null) {
            var messageIndex = mutation.messageIndex();
            var targetContact = messageIndex.chatJid().flatMap(socketHandler.store()::findContactByJid);
            var targetChat = messageIndex.chatJid().flatMap(socketHandler.store()::findChatByJid);
            var targetMessage = targetChat.flatMap(chat -> socketHandler.store()
                    .findMessageById(chat, mutation.messageIndex().messageId().orElse(null)));
            switch (action) {
                case ClearChatAction clearChatAction -> clearMessages(targetChat.orElse(null), clearChatAction);
                case ContactAction contactAction ->
                        updateName(targetContact.orElseGet(() -> createContact(messageIndex)), targetChat.orElseGet(() -> createChat(messageIndex)), contactAction);
                case DeleteChatAction ignored -> targetChat.ifPresent(Chat::removeMessages);
                case DeleteMessageForMeAction ignored ->
                        targetMessage.ifPresent(message -> targetChat.ifPresent(chat -> deleteMessage(message, chat)));
                case MarkChatAsReadAction markAction ->
                        targetChat.ifPresent(chat -> chat.unreadMessagesCount(markAction.read() ? 0 : -1));
                case MuteAction muteAction ->
                        targetChat.ifPresent(chat -> chat.mute(ChatMute.muted(muteAction.muteEndTimestampSeconds())));
                case PinAction pinAction ->
                        targetChat.ifPresent(chat -> chat.pinnedTimestampSeconds(pinAction.pinned() ? (int) mutation.value().timestamp() : 0));
                case StarAction starAction -> targetMessage.ifPresent(message -> message.starred(starAction.starred()));
                case ArchiveChatAction archiveChatAction ->
                        targetChat.ifPresent(chat -> chat.archived(archiveChatAction.archived()));
                case TimeFormatAction timeFormatAction ->
                        socketHandler.store().twentyFourHourFormat(timeFormatAction.twentyFourHourFormatEnabled());
                default -> {}
            }
            socketHandler.onAction(action, messageIndex);
        }
        var setting = value.setting();
        if (setting != null) {
            switch (setting) {
                case EphemeralSetting ephemeralSetting -> showEphemeralMessageWarning(ephemeralSetting);
                case LocaleSetting localeSetting ->
                        socketHandler.updateLocale(localeSetting.locale(), socketHandler.store().locale());
                case PushNameSetting pushNameSetting ->
                        socketHandler.updateUserName(pushNameSetting.name(), socketHandler.store().name());
                case UnarchiveChatsSetting unarchiveChatsSetting ->
                        socketHandler.store().unarchiveChats(unarchiveChatsSetting.unarchiveChats());
                default -> {
                }
            }
            socketHandler.onSetting(setting);
        }
        var features = mutation.value().primaryFeature();
        if (features.isPresent() && !features.get().flags().isEmpty()) {
            socketHandler.onFeatures(features.get());
        }
    }

    private Chat createChat(MessageIndexInfo messageIndex) {
        var chat = messageIndex.chatJid().orElseThrow();
        return socketHandler.store().addNewChat(chat);
    }

    private Contact createContact(MessageIndexInfo messageIndex) {
        var chatJid = messageIndex.chatJid().orElseThrow();
        var contact = socketHandler.store().addContact(chatJid);
        socketHandler.onNewContact(contact);
        return contact;
    }

    private void showEphemeralMessageWarning(EphemeralSetting ephemeralSetting) {
        var logger = System.getLogger("AppStateHandler");
        logger.log(WARNING, "An ephemeral status update was received as a setting. " + "Data: %s".formatted(ephemeralSetting) + "This should not be possible." + " Open an issue on Github please");
    }

    private void clearMessages(Chat targetChat, ClearChatAction clearChatAction) {
        if (targetChat == null) {
            return;
        }

        if (clearChatAction.messageRange().isEmpty()) {
            targetChat.removeMessages();
            return;
        }

        clearChatAction.messageRange()
                .stream()
                .map(ActionMessageRangeSync::messages)
                .flatMap(Collection::stream)
                .map(SyncActionMessage::key)
                .filter(Objects::nonNull)
                .forEach(key -> targetChat.removeMessage(entry -> Objects.equals(entry.id(), key.id())));
    }

    private void updateName(Contact contact, Chat chat, ContactAction contactAction) {
        contactAction.fullName().ifPresent(contact::fullName);
        contactAction.firstName().ifPresent(contact::shortName);
        chat.name(contactAction.name());
    }

    private void deleteMessage(MessageInfo message, Chat chat) {
        chat.removeMessage(message);
        socketHandler.onMessageDeleted(message, false);
    }

    private SyncRecord decodePatches(PatchType name, List patches, LTHashState state) {
        var newState = state.copy();
        var results = patches.stream()
                .map(patch -> decodePatch(name, newState, patch))
                .map(MutationsRecord::records)
                .flatMap(Collection::stream)
                .toList();
        return new SyncRecord(newState, results);
    }

    private MutationsRecord decodePatch(PatchType patchType, LTHashState newState, PatchSync patch) {
        if (patch.hasExternalMutations()) {
            Medias.download(patch.externalMutations())
                    .join()
                    .ifPresent(blob -> handleExternalMutation(patch, blob));
        }

        newState.version(patch.encodedVersion());
        var syncMac = calculateSyncMac(patch, patchType);
        Validate.isTrue(syncMac.isEmpty() || Arrays.equals(syncMac.get(), patch.patchMac()), "sync_mac", HmacValidationException.class);
        var mutations = decodeMutations(patch.mutations(), newState);
        newState.hash(mutations.result().hash());
        newState.indexValueMap(mutations.result().indexValueMap());
        var snapshotMac = generatePatchMac(patchType, newState, patch);
        Validate.isTrue(snapshotMac.isEmpty() || Arrays.equals(snapshotMac.get(), patch.snapshotMac()), "patch_mac", HmacValidationException.class);
        return mutations;
    }

    private void handleExternalMutation(PatchSync patch, byte[] blob) {
        var mutationsSync = Protobuf.readMessage(blob, MutationsSync.class);
        patch.mutations().addAll(mutationsSync.mutations());
    }

    private Optional generatePatchMac(PatchType name, LTHashState newState, PatchSync patch) {
        return getMutationKeys(patch.keyId()).map(mutationKeys -> generateSnapshotMac(newState.hash(), newState.version(), name, mutationKeys.snapshotMacKey()));
    }

    private Optional calculateSyncMac(PatchSync patch, PatchType patchType) {
        return getMutationKeys(patch.keyId()).map(mutationKeys -> generatePatchMac(patch.snapshotMac(), getSyncMutationMac(patch), patch.encodedVersion(), patchType, mutationKeys.patchMacKey()));
    }

    private byte[] getSyncMutationMac(PatchSync patch) {
        return patch.mutations()
                .stream()
                .map(mutation -> mutation.record().value().blob())
                .map(Bytes::of)
                .map(binary -> binary.slice(-Spec.Signal.KEY_LENGTH))
                .reduce(Bytes.newBuffer(), Bytes::append)
                .toByteArray();
    }

    private Optional decodeSnapshot(PatchType name, SnapshotSync snapshot) {
        var mutationKeys = getMutationKeys(snapshot.keyId());
        if (mutationKeys.isEmpty()) {
            return Optional.empty();
        }
        var newState = new LTHashState(name, snapshot.version().version());
        var mutations = decodeMutations(snapshot.records(), newState);
        newState.hash(mutations.result().hash());
        newState.indexValueMap(mutations.result().indexValueMap());
        Validate.isTrue(Arrays.equals(snapshot.mac(), generateSnapshotMac(newState.hash(), newState.version(), name, mutationKeys.get()
                .snapshotMacKey())), "decode_snapshot", HmacValidationException.class);
        return Optional.of(new SyncRecord(newState, mutations.records()));
    }

    private Optional getMutationKeys(KeyId snapshot) {
        return socketHandler.keys()
                .findAppKeyById(snapshot.id())
                .map(AppStateSyncKey::keyData)
                .map(AppStateSyncKeyData::keyData)
                .map(MutationKeys::of);
    }

    private MutationsRecord decodeMutations(List syncs, LTHashState state) {
        var generator = new LTHash(state);
        var mutations = syncs.stream()
                .map(mutation -> decodeMutation(mutation.operation(), mutation.record(), generator))
                .flatMap(Optional::stream)
                .collect(Collectors.toList());
        return new MutationsRecord(generator.finish(), mutations);
    }

    private Optional decodeMutation(RecordSync.Operation operation, RecordSync sync, LTHash generator) {
        var mutationKeys = getMutationKeys(sync.keyId());
        if (mutationKeys.isEmpty()) {
            return Optional.empty();
        }
        var blob = Bytes.of(sync.value().blob());
        var encryptedBlob = blob.cut(-Spec.Signal.KEY_LENGTH).toByteArray();
        var encryptedMac = blob.slice(-Spec.Signal.KEY_LENGTH).toByteArray();
        Validate.isTrue(Arrays.equals(encryptedMac, generateMac(operation, encryptedBlob, sync.keyId()
                .id(), mutationKeys.get().macKey())), "decode_mutation", HmacValidationException.class);
        var result = AesCbc.decrypt(encryptedBlob, mutationKeys.get().encKey());
        var actionSync = Protobuf.readMessage(result, ActionDataSync.class);
        Validate.isTrue(Arrays.equals(sync.index().blob(), Hmac.calculateSha256(actionSync.index(), mutationKeys.get()
                .indexKey())), "decode_mutation", HmacValidationException.class);
        generator.mix(sync.index().blob(), encryptedMac, operation);
        return Optional.of(actionSync);
    }

    private byte[] generateMac(RecordSync.Operation operation, byte[] data, byte[] keyId, byte[] key) {
        var keyData = Bytes.of(operation.content()).append(keyId).toByteArray();
        var last = Bytes.newBuffer(Spec.Signal.MAC_LENGTH - 1).append(keyData.length).toByteArray();
        var total = Bytes.of(keyData, data, last).toByteArray();
        return Bytes.of(Hmac.calculateSha512(total, key)).cut(Spec.Signal.KEY_LENGTH).toByteArray();
    }

    private byte[] generateSnapshotMac(byte[] ltHash, long version, PatchType patchType, byte[] key) {
        var total = Bytes.of(ltHash)
                .append(BytesHelper.longToBytes(version))
                .append(patchType.toString().getBytes(StandardCharsets.UTF_8))
                .toByteArray();
        return Hmac.calculateSha256(total, key);
    }

    private byte[] generatePatchMac(byte[] snapshotMac, byte[] valueMac, long version, PatchType patchType, byte[] key) {
        var total = Bytes.of(snapshotMac)
                .append(valueMac)
                .append(BytesHelper.longToBytes(version))
                .append(patchType.toString().getBytes(StandardCharsets.UTF_8))
                .toByteArray();
        return Hmac.calculateSha256(total, key);
    }

    protected void dispose() {
        attempts.clear();
        runner.cancel();
    }

    private record SyncRecord(LTHashState state, List records) {
    }

    private record SnapshotSyncRecord(PatchType patchType, SnapshotSync snapshot, List patches,
                                      boolean hasMore) {
        public boolean hasSnapshot() {
            return snapshot != null;
        }

        public boolean hasPatches() {
            return patches != null && !patches.isEmpty();
        }
    }

    private record MutationsRecord(LTHash.Result result, List records) {
    }

    private record PatchChunk(PatchType patchType, List records, boolean hasMore) {
    }

    private record PushRequest(PatchRequest patch, LTHashState oldState, LTHashState newState, PatchSync sync) {
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy