:- module(chacha20_prim, [chacha20_prim_u32/2]).

%!  chacha20_prim_u32(+State0, -State)
%
%   State is State0 after ten applications of round//0.  Each round//0
%   performs eight quarter rounds (four on columns, four on diagonals),
%   so this amounts to the 20 rounds of ChaCha20.
%
%   State0 is checked with is_chacha20_block/1 through assertion/1
%   before any work is done.
%
%   NOTE(review): the input state is not added back into the output
%   here; presumably the caller performs that final addition - confirm.
chacha20_prim_u32(State0, State) :-
    assertion(is_chacha20_block(State0)),
    chacha(10, State0, State).

%!  chacha(+Count)//
%
%   Threads the state through Count successive applications of
%   round//0.  A Count of 0 leaves the state unchanged.
chacha(0) --> [].
chacha(Count) -->
    { Count > 0,
      succ(Remaining, Count)
    },
    round,
    chacha(Remaining).

%!  round//
%
%   One double round over the 16-word state: a quarter round on each
%   of the four columns, followed by a quarter round on each of the
%   four diagonals, when the state list is read as a row-major 4x4
%   matrix.  The quarter rounds run in the order listed.
round -->
    { Columns   = [ [0, 4,  8, 12],
                    [1, 5,  9, 13],
                    [2, 6, 10, 14],
                    [3, 7, 11, 15]
                  ],
      Diagonals = [ [0, 5, 10, 15],
                    [1, 6, 11, 12],
                    [2, 7,  8, 13],
                    [3, 4,  9, 14]
                  ]
    },
    foldl(qround, Columns),
    foldl(qround, Diagonals).

%!  qround(+Indices)//
%
%   Quarter round on the four state words found at the 0-based
%   positions Indices = [Ia, Ib, Ic, Id].  Every step rewrites a single
%   word of the state through at/4; the steps must run in this order
%   because each one reads a word written by the previous one.
qround([Ia, Ib, Ic, Id]) -->
    % a += b;  d ^= a;  d <<<= 16
    at([Ia, Ib], qadd),
    at([Id, Ia], qxor),
    at([Id], rot_l32(16)),
    % c += d;  b ^= c;  b <<<= 12
    at([Ic, Id], qadd),
    at([Ib, Ic], qxor),
    at([Ib], rot_l32(12)),
    % a += b;  d ^= a;  d <<<= 8
    at([Ia, Ib], qadd),
    at([Id, Ia], qxor),
    at([Id], rot_l32(8)),
    % c += d;  b ^= c;  b <<<= 7
    at([Ic, Id], qadd),
    at([Ib, Ic], qxor),
    at([Ib], rot_l32(7)).

%!  at(+Indices, :Op, +Before, -After)
%
%   After is the list Before with the element at one 0-based position
%   replaced.
%
%     * at([Target], Op, Before, After)
%       calls call(Op, Old, Raw) on the element Old at Target, passes
%       Raw through trunc_32/2 and stores the outcome at Target.
%     * at([Target, Source], Op, Before, After)
%       does the same, but first reads the element at Source and hands
%       it to Op as a leading extra argument: call(Op, Operand, Old, Raw).
%
%   NOTE(review): trunc_32/2 is defined outside this view; presumably
%   it reduces its input to 32 bits - confirm.
at([Target], Op, Before, After) :-
    % Take the word out, leaving the remaining elements in Others.
    nth0(Target, Before, Old, Others),
    call(Op, Old, Raw),
    trunc_32(Raw, New),
    % Put the new word back at the same position.
    nth0(Target, After, New, Others),
    !.

at([Target, Source], Op, Before, After) :-
    nth0(Source, Before, Operand),
    at([Target], call(Op, Operand), Before, After).

%!  qadd(+Addend, +Word0, -Sum)
%
%   Sum is Word0 + Addend.  No reduction to 32 bits happens here.
qadd(Addend, Word0, Sum) :-
    Sum is Word0 + Addend.
%!  qxor(+Mask, +Word0, -Word)
%
%   Word is the bitwise exclusive-or of Word0 and Mask.
qxor(Mask, Word0, Word) :-
    Word is xor(Word0, Mask).
%!  rot_l32(+Amt, +A0, -A1)
%
%   A1 is the low 32 bits of A0 rotated left by Amt bit positions,
%   for 0 =< Amt =< 32.
%
%   Prolog integers are wider than 32 bits, so a plain shift-and-or
%   keeps the bits shifted out past bit 31, and any bits above bit 31
%   in the operand would be shifted down into the result.  The operand
%   is therefore masked before shifting and the result masked again,
%   so A1 is always in 0..0xFFFFFFFF.
rot_l32(Amt, A0, A1) :-
    Word is A0 /\ 0xFFFFFFFF,
    A1 is ((Word << Amt) \/ (Word >> (32 - Amt))) /\ 0xFFFFFFFF.