#!/usr/local/bin/python

from Crypto.Random import random
from Crypto.Cipher import AES
import sys

# Placeholder flags, presumably for running the challenge without the real
# flag files. Each is exactly 16 bytes, i.e. one AES block, which the
# unpadded AES-ECB encryption further down requires.
flag_1 = b"lol{thisisclearl"
flag_2 = b"ynotheflagsilly}"

try:
    with open("flag_1.txt", "rb") as f:
        flag_1 = f.read().strip()
    with open("flag_2.txt", "rb") as f:
        flag_2 = f.read().strip()
except OSError:
    # Missing or unreadable flag files are tolerated: keep the placeholders.
    # Only I/O errors are swallowed here, so Ctrl-C and genuine bugs still
    # surface instead of being silently ignored.
    pass

# Key length in bytes (16 bytes = 128 bits).
key_size = 16

# Public parameters of the exchange:
#   g -- base of the exponentiation (intended as a generator of the group)
#   n -- modulus; the size of the cyclic group
g = 5
n = 323509627153465998556191522930492845687

# First party's secret exponent a, drawn from [1, n-1], and its public value
# A = g^a mod n.
a = random.randint(1, n - 1)
A = pow(g, a, n)

# Opening scene: the pigeon delivers the first public value A.
print(
    "A strange pigeon alights upon your window:",
    "Quoth the pigeon, ",
    A,
    sep="\n",
)
input("(continue)")

# Qůirm explains what A is and is told to draft the reply.
print(
    "Ponder about this number gets you nowhere, so you order your venerable wizard Qůirm to ponder it for you",
    "\"Elementary!\" Qůirm exclaims. This number is the result of modular exponentiation over a cyclic group!",
    "You don't know what that means, so you order Qůirm to draft the reply",
    sep="\n",
)
input("(continue)")

# Qůirm's side of the first exchange: secret exponent b, public value
# B = g^b mod n, and the shared secret K1 = A^b mod n.
b = random.randint(1, n-1)
B = pow(g, b, n)
K1 = pow(A, b, n)
# Serialize the shared secret to exactly key_size bytes (big-endian) so it
# can be used directly as the AES key. K1 < n < 2**128, so it always fits.
# If the modulus is ever enlarged, to_bytes raises OverflowError instead of
# silently yielding a key of the wrong length.
K1 = K1.to_bytes(length=key_size, byteorder='big')

print("Qůirm leans over and says to the pigeon:")
print(B)

print("The pigeon coos and flies away")
input("(continue)")

print("Shortly after, the pigeon returns again, this time carrying a message: ")

# AES-ECB with no padding: len(flag_1) must be a multiple of 16 bytes,
# otherwise encrypt() raises ValueError.
cipher = AES.new(K1, AES.MODE_ECB)
msg = cipher.encrypt(flag_1)
print(msg)

# Flush explicitly: the next step reads raw bytes from stdin, so the prompt
# must reach the player before we block.
print("Qůirm gasps. You ask what the message says, but he replies \"there's no time to explain! just give me your response!\"", flush=True)

# Read the player's 16-byte message straight from the binary stdin buffer, so
# arbitrary bytes (NUL, non-UTF-8) survive. The text layer behind input()
# would try to decode them.
# NOTE(review): input() above reads through the sys.stdin text wrapper, which
# may buffer bytes beyond the newline. If a client sends these 16 bytes in the
# same write as the preceding "(continue)" newline, they could be swallowed
# there and this read would block. Send them only after the prompt appears.
while True:
    # A blocking BufferedReader.read(16) returns fewer than 16 bytes only when
    # stdin reaches EOF. It returns None only in non-blocking mode; treat that
    # as no data.
    msg = sys.stdin.buffer.read(16) or b""
    if len(msg) == 16:
        break
    print("Qůirm looks at you and shakes his head.")
    print("\"This is a 16-byte pigeon\"", flush=True)
    if not msg:
        # stdin is exhausted. Retrying would loop forever printing the
        # rejection, so stop here.
        sys.exit(1)
print("Qůirm jots down your message, consoling you that he will encrypt it with perfect secrecy with a proprietary One-Time-Pigeon algorithm")

# Second exchange: both secret exponents are stepped by one, giving
# K2 = (g^(a+1))^(b+1) mod n for the original a and b. It is serialized to
# key_size bytes like K1.
a += 1
b += 1
A = pow(g, a, n)
K2 = pow(A, b, n)
K2 = K2.to_bytes(length=key_size, byteorder='big')

# XOR the player's 16 bytes with K2, byte by byte. Building the result with
# bytes() avoids calling int.to_bytes() with no arguments, which is a
# TypeError before Python 3.11. It also avoids quadratic `+=` concatenation.
ciphertext = bytes(c ^ k for c, k in zip(msg, K2))
print("Qůirm whispers to the pigeon... you can barely hear the phrase")
print(ciphertext)
input("(continue)")

# Third exchange: exponents stepped once more, giving the AES key
# K3 = (g^(a+2))^(b+2) mod n for the original a and b, as key_size bytes.
a += 1
b += 1
A = pow(g, a, n)
K3 = pow(A, b, n)
K3 = K3.to_bytes(length=key_size, byteorder='big')

# AES-ECB with no padding: len(flag_2) must be a multiple of 16 bytes,
# otherwise encrypt() raises ValueError.
cipher = AES.new(K3, AES.MODE_ECB)
msg = cipher.encrypt(flag_2)
print(msg)
