/*
 * Decompiled with CFR 0.152.
 */
package it.auties.whatsapp.crypto;

import it.auties.whatsapp.crypto.Hkdf;
import it.auties.whatsapp.model.sync.LTHashState;
import it.auties.whatsapp.model.sync.RecordSync;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;

/**
 * Accumulates a batch of record mutations and applies them to a hash buffer.
 *
 * <p>Each mutation is identified by an index MAC and carries a value MAC. Calls to
 * {@link #mix(byte[], byte[], RecordSync.Operation)} only queue work; the hash itself is
 * recomputed by {@link #finish()}, which first subtracts every replaced/removed value MAC and
 * then adds every newly set value MAC. Both steps treat the buffers as sequences of unsigned
 * 16-bit little-endian words and wrap around on overflow/underflow.
 *
 * <p>Instances keep mutable state and perform no synchronization.
 */
public class LTHash {
    /** Number of bytes requested from {@code Hkdf.extractAndExpand} for every mixed value MAC. */
    private static final int EXPAND_SIZE = 128;

    /**
     * Constant second argument handed to {@code Hkdf.extractAndExpand}. Shared by all instances
     * and never modified by this class.
     */
    private static final byte[] SALT = "WhatsApp Patch Integrity".getBytes(StandardCharsets.UTF_8);

    /** Hash the pending operations are applied to; never modified in place by this class. */
    private final byte @NonNull [] hash;

    /** Base64-encoded index MAC mapped to the value MAC currently associated with it. */
    @NonNull
    private final Map<String, byte[]> indexValueMap;

    /** Value MACs queued to be added to the hash by {@link #finish()}. */
    @NonNull
    private final List<byte[]> add;

    /** Value MACs queued to be subtracted from the hash by {@link #finish()}. */
    @NonNull
    private final List<byte[]> subtract;

    /**
     * Creates an accumulator seeded from the given state.
     *
     * <p>The index/value map is copied, so later {@code mix} calls do not touch the map held by
     * {@code hash}. The hash array itself is referenced, not copied.
     *
     * @param hash the state providing the starting hash and the known index/value pairs
     */
    public LTHash(LTHashState hash) {
        this.hash = hash.hash();
        this.indexValueMap = new HashMap<>(hash.indexValueMap());
        this.add = new ArrayList<>();
        this.subtract = new ArrayList<>();
    }

    /**
     * Queues a single mutation.
     *
     * <p>For {@code REMOVE} the entry stored under {@code indexMac} is dropped and its value MAC
     * is queued for subtraction; if no entry exists the call is a no-op. For any other operation
     * {@code valueMac} is queued for addition and stored under {@code indexMac}, and the value MAC
     * it replaces (if any) is queued for subtraction.
     *
     * @param indexMac  the MAC identifying the record; used Base64-encoded as the map key
     * @param valueMac  the MAC of the record's value; ignored for {@code REMOVE}
     * @param operation the kind of mutation
     */
    public void mix(byte[] indexMac, byte[] valueMac, RecordSync.Operation operation) {
        String indexMacBase64 = Base64.getEncoder().encodeToString(indexMac);
        byte[] previous;
        if (operation == RecordSync.Operation.REMOVE) {
            // Map.remove reports the old value, so no separate lookup is needed.
            previous = this.indexValueMap.remove(indexMacBase64);
        } else {
            this.add.add(valueMac);
            previous = this.indexValueMap.put(indexMacBase64, valueMac);
        }
        if (previous != null) {
            // Whatever was stored before no longer contributes to the hash.
            this.subtract.add(previous);
        }
    }

    /**
     * Applies every queued mutation to the starting hash.
     *
     * <p>Subtractions are applied before additions. The queues are not cleared, and the starting
     * hash is not replaced, so repeated calls recompute from the same starting hash.
     *
     * @return the resulting hash together with a snapshot copy of the index/value map; value
     *         arrays inside the snapshot are shared with this instance, not cloned
     */
    public Result finish() {
        byte[] subtracted = this.perform(this.hash, false);
        byte[] added = this.perform(subtracted, true);
        // Hand out a copy so that later mix() calls cannot alter an already returned result.
        return new Result(added, new HashMap<>(this.indexValueMap));
    }

    /**
     * Folds every queued item of one kind into {@code input}.
     *
     * @param input the hash to start from
     * @param sum   {@code true} to add the items of {@link #add}, {@code false} to subtract the
     *              items of {@link #subtract}
     * @return the updated hash, or {@code input} itself when the selected queue is empty
     */
    private byte[] perform(byte[] input, boolean sum) {
        byte[] current = input;
        for (byte[] item : sum ? this.add : this.subtract) {
            current = this.perform(current, item, sum);
        }
        return current;
    }

    /**
     * Adds or subtracts the HKDF expansion of {@code buffer} to/from {@code input}, word by word.
     *
     * <p>Words are unsigned 16-bit little-endian values; results are truncated to 16 bits.
     * NOTE(review): this reads {@code input.length} bytes from the expansion, so it presumably
     * relies on the hash having an even length no larger than {@link #EXPAND_SIZE} - confirm
     * against {@code Hkdf.extractAndExpand} and the callers.
     *
     * @param input  the current hash; left untouched
     * @param buffer the value MAC to expand
     * @param sum    {@code true} to add, {@code false} to subtract
     * @return a new array of {@code input.length} bytes holding the combined words
     */
    private byte[] perform(byte[] input, byte[] buffer, boolean sum) {
        byte[] expanded = Hkdf.extractAndExpand(buffer, SALT, EXPAND_SIZE);
        ByteBuffer current = ByteBuffer.wrap(input).order(ByteOrder.LITTLE_ENDIAN);
        ByteBuffer delta = ByteBuffer.wrap(expanded).order(ByteOrder.LITTLE_ENDIAN);
        byte[] result = new byte[input.length];
        // Writing through a wrapping buffer fills 'result' directly, avoiding a second copy.
        ByteBuffer output = ByteBuffer.wrap(result).order(ByteOrder.LITTLE_ENDIAN);
        for (int offset = 0; offset < input.length; offset += Short.BYTES) {
            int first = Short.toUnsignedInt(current.getShort(offset));
            int second = Short.toUnsignedInt(delta.getShort(offset));
            // The narrowing cast keeps the low 16 bits, i.e. arithmetic modulo 2^16.
            output.putShort(offset, (short) (sum ? first + second : first - second));
        }
        return result;
    }

    /**
     * Outcome of {@link LTHash#finish()}.
     *
     * @param hash          the hash after all queued mutations were applied
     * @param indexValueMap Base64-encoded index MACs mapped to their current value MACs
     */
    public record Result(byte[] hash, Map<String, byte[]> indexValueMap) {
    }
}

