:- module(chacha20_poly1305, [
    chacha20_poly1305_encrypt/7, 
    chacha20_poly1305_decrypt/7,
    test_chacha20_poly1305/0
]).
:- use_module(poly1305).

% AEAD construction mostly following RFC 8439, section 2.8:
% https://datatracker.ietf.org/doc/html/rfc8439#section-2.8

%! chacha20_poly1305_encrypt(+Aad, +Key, +Nonce, +CounterOffset, +Text,
%!                           -Ciphertext, -Tag) is det.
%
%  Encrypts the byte list Text and authenticates it together with the
%  additional data Aad.
%
%  The ChaCha20 block at CounterOffset yields the Poly1305 one-time key
%  (see poly1305_generate_key/4). The keystream for Text begins at block
%  CounterOffset + 1. Tag is the Poly1305 MAC of the data assembled by
%  create_mac_data/3 from Aad and Ciphertext.
%
%  The assertion/1 checks are sanity checks only. They may be compiled
%  away when SWI-Prolog's optimise flag is set.
chacha20_poly1305_encrypt(Aad, Key, Nonce, CounterOffset, Text, Ciphertext, Tag) :-
    % Input shape checks.
    assertion(is_u8s(Aad)),
    assertion(is_chacha20_key(Key)),
    assertion(is_chacha20_nonce(Nonce)),
    assertion(is_u8s(Text)),

    % One-time MAC key taken from the first keystream block.
    poly1305_generate_key(Key, Nonce, CounterOffset, OneTimeKey),

    % Encrypt starting one block past the MAC-key block.
    FirstDataBlock is CounterOffset + 1,
    chacha20_cipher(Key, Nonce, FirstDataBlock, Text, Ciphertext),

    % Authenticate AAD and ciphertext (encrypt-then-MAC).
    create_mac_data(Aad, Ciphertext, AuthInput),
    poly1305(OneTimeKey, AuthInput, Tag),

    % Output shape checks.
    assertion(is_u8s(Ciphertext)),
    assertion(is_poly1305_tag(Tag)).

%! chacha20_poly1305_decrypt(+Aad, +Key, +Nonce, +CounterOffset, -Text,
%!                           +Ciphertext, +Tag) is det.
%
%  Verifies Tag over Aad and Ciphertext and, only if it matches, decrypts
%  Ciphertext into the byte list Text.
%
%  The Poly1305 one-time key is derived from block CounterOffset, and the
%  keystream for the ciphertext starts at block CounterOffset + 1. This
%  mirrors chacha20_poly1305_encrypt/7.
%
%  The tag check is an explicit goal rather than assertion/1. Assertions
%  are debugging aids that may be compiled away (optimise flag), and
%  authentication must never be skipped. An unbound or malformed Tag is
%  treated as a mismatch rather than being unified with the expected tag.
%
%  @throws error(domain_error(poly1305_tag, Tag), _) if Tag does not
%          authenticate Aad and Ciphertext. No plaintext is produced.
chacha20_poly1305_decrypt(Aad, Key, Nonce, CounterOffset, Text, Ciphertext, Tag) :-
    assertion(is_u8s(Aad)),
    assertion(is_chacha20_key(Key)),
    assertion(is_chacha20_nonce(Nonce)),
    assertion(is_u8s(Ciphertext)),
    assertion(is_poly1305_tag(Tag)),

    poly1305_generate_key(Key, Nonce, CounterOffset, Poly1305Key),

    create_mac_data(Aad, Ciphertext, MacData),
    poly1305(Poly1305Key, MacData, ExpectedTag),

    % Authenticate BEFORE decrypting, so forged input never yields plaintext.
    (   tags_equal(ExpectedTag, Tag)
    ->  true
    ;   throw(error(domain_error(poly1305_tag, Tag),
                    context(chacha20_poly1305_decrypt/7,
                            'authentication tag mismatch')))
    ),

    Counter is CounterOffset + 1,
    chacha20_cipher(Key, Nonce, Counter, Ciphertext, Text),

    assertion(is_u8s(Text)).

%! tags_equal(+Tag1, +Tag2) is semidet.
%
%  True if Tag1 and Tag2 are proper lists of the same length with
%  pairwise-equal integer elements.
%
%  Differences are OR-accumulated over every element instead of stopping
%  at the first mismatch. This reduces timing leakage, although Prolog
%  gives no hard constant-time guarantee.
tags_equal(Tag1, Tag2) :-
    is_list(Tag1),
    is_list(Tag2),
    length(Tag1, Len),
    length(Tag2, Len),
    foldl(accumulate_byte_diff, Tag1, Tag2, 0, Diff),
    Diff =:= 0.

%  Fold step for tags_equal/2: OR the XOR of two elements into the accumulator.
accumulate_byte_diff(Byte1, Byte2, Acc0, Acc) :-
    Acc is Acc0 \/ (Byte1 xor Byte2).

%! create_mac_data(+Aad, +Ciphertext, -MacData) is det.
%
%  Assembles the Poly1305 input used by the AEAD construction:
%
%      Aad ++ pad16(Aad) ++ Ciphertext ++ pad16(Ciphertext)
%          ++ AadLength ++ CiphertextLength
%
%  Each length is the element count of the corresponding list, encoded by
%  as_int_le/3 with width 8. Presumably this is 8-byte little-endian, as
%  RFC 8439 requires (TODO confirm against as_int_le/3).
create_mac_data(Aad, Ciphertext, MacData) :-
    Sections = [Aad, Ciphertext],
    maplist(pad16, Sections, [AadPad, CtPad]),
    maplist(length, Sections, [AadLen, CtLen]),
    maplist(as_int_le(8), [AadLenBytes, CtLenBytes], [AadLen, CtLen]),
    append([Aad, AadPad,
            Ciphertext, CtPad,
            AadLenBytes, CtLenBytes],
           MacData).

%! pad16(+X, -Padding) is det.
%
%  Padding is the list of zeros that, appended to X, makes the total length
%  a multiple of 16. It is empty when length(X) is already a multiple of 16.
pad16(X, Padding) :-
    length(X, Len),
    % Floored mod takes the sign of the divisor, so (-Len) mod 16 is the
    % distance up to the next multiple of 16 (0 for exact multiples).
    Missing is (-Len) mod 16,
    zeros(Missing, Padding).

%! poly1305_generate_key(+Key, +Nonce, +CounterOffset, -Poly1305Key) is det.
%
%  Derives the 32-byte Poly1305 one-time key: the ChaCha20 block at
%  CounterOffset is converted to bytes and its first 32 bytes are taken.
poly1305_generate_key(Key, Nonce, CounterOffset, Poly1305Key) :-
    chacha20_block(Key, Nonce, CounterOffset, _, BlockWords),
    u8s_to_u32s_le(BlockBytes, BlockWords),
    assertion(length(BlockBytes, 64)),
    % Fix the key length first so append/3 splits deterministically.
    length(Poly1305Key, 32),
    append(Poly1305Key, _Unused, BlockBytes).

%! test_chacha20_poly1305 is det.
%
%  Runs this module's self-tests in order: padding, one-time key
%  derivation, and the end-to-end AEAD vector.
test_chacha20_poly1305 :-
    Tests = [ test_pad16,
              test_poly1305_generate_key,
              test_e2e
            ],
    maplist(call, Tests).

%! test_pad16 is det.
%
%  Checks pad16/2 on the empty list, on a single byte, and on lengths around
%  the 16-byte block boundary (15, 16 and 17). The boundary lengths are
%  where an off-by-one in the padding formula would show up.
test_pad16 :-
    assertion( pad16([], []) ),
    assertion( pad16([1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) ),

    % One short of a block: exactly one zero of padding.
    findall(7, between(1, 15, _), Len15),
    assertion( pad16(Len15, [0]) ),

    % An exact block needs no padding.
    findall(7, between(1, 16, _), Len16),
    assertion( pad16(Len16, []) ),

    % One past a block: pad up to 32, i.e. 15 zeros.
    findall(7, between(1, 17, _), Len17),
    findall(0, between(1, 15, _), FifteenZeros),
    assertion( pad16(Len17, FifteenZeros) ).

%! test_poly1305_generate_key is det.
%
%  Checks the one-time key derivation against the key-generation vector of
%  RFC 8439, section 2.6.2. The RFC's 12-byte nonce begins with four zero
%  bytes; with counter offset 0 they correspond to the 8-byte nonce below.
test_poly1305_generate_key :-
    KeyHex      = "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f",
    NonceHex    = "0001020304050607",
    ExpectedHex = "8ad5a08b905f81cc815040274ab29471a833b637e3fd0da508dbb8e2fdd1a646",
    hex_bytes(KeyHex, Key),
    hex_bytes(NonceHex, Nonce),
    hex_bytes(ExpectedHex, Want),
    poly1305_generate_key(Key, Nonce, 0, Got),
    assertion(Got = Want).


%! test_e2e is det.
%
%  End-to-end check against the AEAD test vector of RFC 8439, section
%  2.8.2. It verifies the derived one-time key, the ciphertext and the tag,
%  then decrypts the expected ciphertext back to the original plaintext.
test_e2e :-
    hex_bytes("50515253c0c1c2c3c4c5c6c7", Aad),
    hex_bytes("808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f", Key),
    hex_bytes("4041424344454647", Nonce),
    string_bytes(
        "Ladies and Gentlemen of the class of \x27\99: If I could offer you only one tip for the future, sunscreen would be it.",
        Text, utf8
    ),
    % RFC 8439 uses a 12-byte (96-bit) nonce, here 07 00 00 00 40 41 .. 47,
    % with a 32-bit block counter. This implementation takes an 8-byte nonce
    % plus a wider counter. The RFC's leading nonce bytes 07 00 00 00 (little-
    % endian 7) are therefore folded into the high 32 bits of the counter
    % offset, so the RFC vector can be reused unchanged.
    % NOTE(review): relies on chacha20_block/5 laying out the counter's high
    % word where the RFC places the first nonce word -- confirm there.
    CounterOffset is 7 << 32,
    poly1305_generate_key(Key, Nonce, CounterOffset, PolyKey),
    chacha20_poly1305_encrypt(Aad, Key, Nonce, CounterOffset, Text, Ciphertext, Tag),

    hex_bytes("7bac2b252db447af09b67a55a4e955840ae1d6731075d9eb2a9375783ed553ff", ExpectedPolyKey),

    hex_bytes("d31a8d34648e60db7b86afbc53ef7ec2a4aded51296e08fea9e2b5a736ee62d63dbea45e8ca9671282fafb69da92728b1a71de0a9e060b2905d6a5b67ecd3b3692ddbd7f2d778b8c9803aee328091b58fab324e4fad675945585808b4831d7bc3ff4def08e4b7a9de576d26586cec64b6116", ExpectedCiphertext),
    hex_bytes("1ae10b594f09e26a7e902ecbd0600691", ExpectedTag),

    assertion(PolyKey = ExpectedPolyKey),
    assertion(Ciphertext = ExpectedCiphertext),
    assertion(Tag = ExpectedTag),

    % Round trip: the RFC ciphertext and tag must decrypt to the plaintext.
    chacha20_poly1305_decrypt(
        Aad, Key, Nonce, CounterOffset, Plaintext, ExpectedCiphertext, ExpectedTag
    ),
    assertion(Plaintext = Text).