import Data.Bits (xor)
{- Setup for Exercise 1 -}
-- | Direction in which the Vigenère cipher is applied: with 'Encrypt' a
-- letter is shifted forward by the key letter, with 'Decrypt' it is shifted
-- back by the same amount.
data EncMode = Encrypt | Decrypt

-- | True exactly for the 26 ASCII capital letters 'A' .. 'Z'.
isAsciiUpper :: Char -> Bool
isAsciiUpper c = c `elem` ['A' .. 'Z']

-- | Rotate an ASCII capital letter @n@ places through the alphabet, wrapping
-- around between 'Z' and 'A'. A negative @n@ rotates backwards. Every other
-- character is returned untouched.
shift :: Int -> Char -> Char
shift n c =
  if isAsciiUpper c
    then toEnum (base + (fromEnum c - base + n) `mod` 26)
    else c
  where
    -- Code point of 'A', used to map letters to 0..25 and back.
    base = fromEnum 'A'

-- | Transform a single text character @c@ with the key character @k@.
-- The key letter's distance from 'A' is the shift amount: it is applied
-- forwards when encrypting and backwards when decrypting.
vigenereTable :: EncMode -> (Char, Char) -> Char
vigenereTable mode (c, k) = case mode of
  Encrypt -> shift keyOffset c
  Decrypt -> shift (negate keyOffset) c
  where
    keyOffset = fromEnum k - fromEnum 'A'

-- | Encrypt or decrypt the string @xs@ with the repeating keyword @key@.
--
-- Only ASCII capital letters consume a key letter; for any other character
-- the key position stays where it is. An empty @key@ leaves @xs@ unchanged
-- (there is nothing to shift by).
--
-- For example, with key GHCUP:
-- @vigenere Encrypt "HELLOWORLD" "GHCUP" = "NLNFDCVTFS"@ and
-- @vigenere Decrypt "NLNFDCVTFS" "GHCUP" = "HELLOWORLD"@.
vigenere :: EncMode -> String -> String -> String
vigenere _ xs [] = xs
vigenere m xs key = go xs (cycle key)
  where
    -- Walk the text and the endlessly repeated key side by side, emitting
    -- each output character as soon as it is known (linear time, no
    -- repeated appends to an accumulator).
    go :: String -> String -> String
    go (c : cs) ks@(k : rest) =
      vigenereTable m (c, k) : go cs (if isAsciiUpper c then rest else ks)
    go _ _ = []

{- Setup for Exercise 2 -}
-- | A page of text, stored as the list of its lines.
data Page = Page [String]

-- | A page is shown as its lines in order, each terminated by a newline.
instance Show Page where
  show (Page rows) = concatMap (++ "\n") rows

-- | Sample page (a short letter) used by the page and redaction tests.
testPage :: Page
testPage =
  Page
    [ "Dear Page,"
    , ""
    , "Please show us"
    , "some lines!"
    , ""
    , "Sincerely,"
    , "X."
    ]

infixl 6 <+>
infixl 6 <->

-- | Types with a neutral element 'zero', a "plus" operation '<+>' and a
-- "minus" operation '<->'. Both operators are left-associative at
-- precedence 6, so @a \<+\> b \<-\> c@ reads as @(a \<+\> b) \<-\> c@.
--
-- The tests in this file check three laws on instances:
-- @x \<+\> zero == x@, @zero \<+\> x == x@ and @x \<+\> x \<-\> x == x@.
class ZPM a where
  zero :: a
  (<+>) :: a -> a -> a
  (<->) :: a -> a -> a

-- | 'Int' with 0 as the neutral element and ordinary addition/subtraction.
instance ZPM Int where
  zero = 0
  a <+> b = a + b
  a <-> b = a - b


{- Exercise 1.1 -}
-- | Count how often each of the letters 'A' .. 'Z' occurs in @xs@.
--
-- The result always has exactly 26 entries (index 0 is the count of 'A',
-- index 25 the count of 'Z'). Characters other than ASCII capital letters
-- are not counted.
freqs :: String -> [Int]
freqs xs = [length (filter (== letter) xs) | letter <- ['A' .. 'Z']]

-- | Index of coincidence of the ASCII capital letters in @xs@: the
-- probability that two letters drawn from the text without replacement are
-- equal, i.e. @sum [k * (k - 1)] / (n * (n - 1))@ where @k@ ranges over the
-- per-letter counts and @n@ is the number of letters.
--
-- Characters other than 'A' .. 'Z' are ignored. With fewer than two letters
-- there is no pair to compare, so the result is 0 (rather than the NaN that
-- @0 / 0@ would produce).
indexOfCoincidence :: String -> Float
indexOfCoincidence xs
  | total < 2 = 0
  | otherwise = fromIntegral matchingPairs / fromIntegral (total * (total - 1))
  where
    -- Per-letter counts are computed locally so this function does not
    -- depend on any other definition in the file.
    counts = [length (filter (== letter) xs) | letter <- ['A' .. 'Z']]
    total = sum counts
    matchingPairs = sum [k * (k - 1) | k <- counts]

-- | Split @xs@ into @n@ interleaved substrings: substring @i@ (counting from
-- 0) holds the characters at positions @i@, @i + n@, @i + 2n@, ...
--
-- For example @nSubstrings "HELLO" 2 = ["HLO", "EL"]@. The result always has
-- @n@ entries (some possibly empty) when @n >= 1@, and is empty for
-- @n <= 0@.
nSubstrings :: String -> Int -> [String]
nSubstrings xs n = [everyNth (drop offset xs) | offset <- [0 .. n - 1]]
  where
    -- Keep the first character, skip the following n - 1, and repeat.
    everyNth [] = []
    everyNth (c : rest) = c : everyNth (drop (n - 1) rest)

-- | Rotate a list to the left by @n@ positions: the first @n@ elements are
-- moved to the end, e.g. @rotate 3 "Haskell" = "kellHas"@.
--
-- If @n@ is negative or larger than the length of the list, the list is
-- returned unchanged. The length of @xs@ is inspected, so @xs@ must be
-- finite.
rotate :: Int -> [a] -> [a]
rotate n xs
  | n < 0 || n > length xs = xs
  | otherwise = back ++ front
  where
    (front, back) = splitAt n xs

{- Exercise 1.2 -}
-- | All zero-based indices at which @x@ occurs in @xs@, in ascending order.
positions :: Eq a => a -> [a] -> [Int]
positions x xs = map fst (filter ((== x) . snd) (zip [0 ..] xs))

-- | Naively guess the keyword length: the @n@ in @[1 .. 8]@ whose @n@
-- interleaved substrings of @xs@ have the highest average index of
-- coincidence. If several lengths tie, the largest one is returned.
--
-- Lengths whose average is NaN (which can happen when a substring is too
-- short for 'indexOfCoincidence' to be defined) are skipped; if no length
-- yields a usable average the result is 1.
keywordLen :: String -> Int
keywordLen xs = case candidates of
  [] -> 1
  _ -> snd (maximum candidates)
  where
    -- Pairing each average with its length lets 'maximum' return the length
    -- directly, without searching for the winner's position afterwards.
    candidates = [(avg, n) | n <- [1 .. 8], let avg = averageIoc n, not (isNaN avg)]
    averageIoc n = mean (map indexOfCoincidence (nSubstrings xs n))
    mean vals = sum vals / fromIntegral (length vals)

-- | Chi-squared statistic of observed values @os@ against expected values
-- @es@: the sum of @(o - e)^2 / e@ over corresponding pairs. If the lists
-- differ in length, the surplus of the longer one is ignored.
chisqr :: [Float] -> [Float] -> Float
chisqr os es = sum (zipWith term os es)
  where
    term o e = (o - e) ^ 2 / e

-- | Not implemented yet: evaluating the result fails with 'undefined'.
--
-- What the rest of this file expects of it: 'crackVigenere' uses the index
-- of the smallest returned value as the shift that yields a key letter, and
-- the @chisqrVals@ tests expect 26 values per input, with identical results
-- for @"THISISNORMALTEXT"@ and @"THIS IS NORMAL TEXT"@.
-- NOTE(review): presumably entry @i@ is the 'chisqr' statistic of the text's
-- letter distribution, shifted by @i@ places, against 'englishFreqs' --
-- confirm against the exercise statement before implementing.
chisqrVals :: String -> [Float]
chisqrVals xs = undefined


-- | Try to decrypt a Vigenère ciphertext without knowing the key.
--
-- The keyword length is guessed with 'keywordLen' on the capital letters of
-- @xs@; those letters are then split into one interleaved substring per key
-- position, and for each substring the shift with the smallest 'chisqrVals'
-- entry is taken as the key letter. Finally @xs@ is decrypted with the
-- recovered key.
crackVigenere :: String -> String
crackVigenere xs = vigenere Decrypt xs key
  where
    -- Only capital letters take part in the cipher, so only they are analysed.
    letters = filter isAsciiUpper xs
    keyLength = keywordLen letters
    key = map keyLetter (nSubstrings letters keyLength)
    -- Key letter for one key position: 'A' moved by the best-scoring shift.
    keyLetter column = shift (bestShift (chisqrVals column)) 'A'
    -- Index of the smallest score (the first one on ties); 0 when there are
    -- no scores at all, so this never fails on an empty list.
    bestShift vals = case zip vals [0 ..] of
      [] -> 0
      scored -> snd (minimum scored)

 
-- | Letter frequencies of English text, one entry per letter from A to Z.
-- The 26 values are percentages (they add up to roughly 100).
--
-- The type is fixed to @[Float]@ so the list can be passed to 'chisqr'
-- directly instead of depending on numeric defaulting.
--
-- Source: https://en.wikipedia.org/wiki/Letter_frequency
englishFreqs :: [Float]
englishFreqs =
  [ 8.2, 1.5, 2.8, 4.3, 13, 2.2, 2, 6.1, 7, 0.15, 0.77, 4, 2.4, 6.7
  , 7.5, 1.9, 0.095, 6, 6.3, 9.1, 2.8, 0.98, 2.4, 0.15, 2, 0.074
  ]

-- | Sample ciphertext. The literal is split with string gaps, one source
-- line per line of text; the resulting string contains no extra characters.
--
-- Text source: https://wiki.haskell.org/Haiku
ciphertext :: String
ciphertext =
  "OAAUYDS IA MIFJIAO\n\
  \ZMUCBSIFHL EOFD-AYXOX SUD VOUL\n\
  \PT QC FARE PKCCB\n\n\
  \YMCNWYDII CL DOZUYV.\n\
  \AOLKS AA IA CNASL EYLCPNO\n\
  \RUKRETV CK SISO NZHT!\n\n\
  \BIJW JHMMEWY: \"YWE ZGVL!\n\
  \ERUL FOC BYIBEAD GSREA XI KLNAO -\n\
  \LWAHQXE QVUZ LUV JOLO.\"\n\n\
  \YEWTG VCKA EZBIJ:\n\
  \AHQC YDLGIXN HYOOBUE KIMC\n\
  \MG CEZI KMPCSVS\n\n\
  \ZHSSOFD'Z CZIJLPC NYLE\n\
  \PS VKNMYAT DI KVMM PIDRS\n\
  \IXX KV IA RUARU'A"


{- Exercise 2 -}
-- | Build a 'Page' from a text by splitting it into lines at newline
-- characters, e.g. @pageOf "This is\\na test."@ has the two lines
-- @"This is"@ and @"a test."@. The empty string gives a page without lines.
pageOf :: String -> Page
pageOf text = Page (lines text)

-- | Characters form a 'ZPM' via addition and subtraction of code points
-- modulo the size of the 'Char' range, with the NUL character as 'zero'.
-- Modular arithmetic (rather than e.g. xor-ing code points) guarantees the
-- result is always a valid 'Char', so neither operator can fail.
instance ZPM Char where
  zero = '\0'
  a <+> b = toEnum ((fromEnum a + fromEnum b) `mod` 0x110000)
  a <-> b = toEnum ((fromEnum a - fromEnum b) `mod` 0x110000)

-- | Combine two lists element by element with @op@. When one list is
-- shorter, its missing elements are treated as 'zero', so the result is as
-- long as the longer input.
zipPadded :: ZPM a => (a -> a -> a) -> [a] -> [a] -> [a]
zipPadded op (x : xs) (y : ys) = op x y : zipPadded op xs ys
zipPadded op xs [] = map (`op` zero) xs
zipPadded op [] ys = map (zero `op`) ys

-- | Lists are combined element-wise; the empty list is 'zero'. The shorter
-- operand is padded with the element type's 'zero'.
instance ZPM a => ZPM [a] where
  zero = []
  (<+>) = zipPadded (<+>)
  (<->) = zipPadded (<->)

-- | Pages are combined line by line (and, within a line, character by
-- character); the page without lines is 'zero'.
instance ZPM Page where
  zero = Page []
  Page a <+> Page b = Page (a <+> b)
  Page a <-> Page b = Page (a <-> b)

-- | Replace every occurrence of @word@ on the page by a run of @X@
-- characters of the same length.
--
-- Each line is scanned from left to right and occurrences do not overlap,
-- so redacting @"haha"@ in @"hahaha"@ gives @"XXXXha"@. Occurrences are only
-- searched within a line, never across a line break. An empty @word@ leaves
-- the page unchanged.
redact :: String -> Page -> Page
redact word (Page rows)
  | null word = Page rows
  | otherwise = Page (map censor rows)
  where
    len = length word
    mask = replicate len 'X'
    -- Blank out a match at the current position and continue behind it;
    -- otherwise keep the current character and move on by one.
    censor [] = []
    censor line@(c : rest)
      | take len line == word = mask ++ censor (drop len line)
      | otherwise = c : censor rest

{- Tests -}
-- | Run all test cases of this file in order, printing one result line per
-- case via 'check' (single expected value) or 'checkAll' (one expected value
-- per input). Expected values are given as the 'show'n form of the result.
tests = do
    -- Exercise 1.1: letter counts, index of coincidence, interleaved
    -- substrings and list rotation.
    check "freqs1" "[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]" (freqs "")
    check "freqs2" "[2,2,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]" (freqs "ABCABCZ")
    check "freqs3" "[2,0,0,0,2,1,0,1,1,0,1,4,0,1,0,0,0,1,2,0,1,0,0,0,1,0]" (freqs "HASKELLISREALLYFUN")
    check "indexOfCoincidence1" "0.46666667" (indexOfCoincidence "AAAABB")
    check "indexOfCoincidence2" "5.4347824e-2" (indexOfCoincidence "SOMENONRANDOMENGLISHTEXT")
    check "indexOfCoincidence3" "1.0" (indexOfCoincidence "HHHHHHHHHHH")
    check "nSubstrings1" "[\"\",\"\",\"\",\"\",\"\"]" (nSubstrings "" 5)
    check "nSubstrings2" "[\"HLO\",\"EL\"]" (nSubstrings "HELLO" 2)
    check "nSubstrings3" "[\"HASKELL\"]" (nSubstrings "HASKELL" 1)
    check "nSubstrings4" "[]" (nSubstrings "HASKELL" 0)
    check "nSubstrings5" "[]" (nSubstrings "HASKELL" (-1))
    check "nSubstrings6" "[\"AT\",\"NE\",\"ES\",\"WT\"]" (nSubstrings "ANEWTEST" 4)
    check "rotate1" "\"kellHas\"" (rotate 3 "Haskell" :: String)
    check "rotate2" "\"Haskell\"" (rotate 0 "Haskell" :: String)
    check "rotate3" "\"Haskell\"" (rotate (-1) "Haskell" :: String)
    check "rotate4" "\"Haskell\"" (rotate 11 "Haskell" :: String)
    check "rotate5" "[4,5,1,2,3]" (rotate 3 [1, 2, 3, 4, 5] :: [Int])
    -- Exercise 1.2: chi-squared values; cases 2 and 3 expect the same
    -- result for a text with and without spaces.
    check "chisqrVals1" "[669.6498,2225.9998,3471.8472,6567.0854,1119.9314,135035.56,4900.4194,66567.08,4067.0852,10104.501,3471.8472,999.3201,1487.7205,1567.0857,105163.56,5163.5767,1233.7524,1392.9565,4067.0852,2400.4187,12887.435,66567.07,1328.9905,1539.7634,4900.419,4445.8735]" (chisqrVals "E")
    check "chisqrVals2" "[304.35477,1074.429,2130.72,4389.452,481.68253,903.0427,198.86958,968.2157,1090.9866,3630.6326,3094.483,1148.6993,752.69904,964.89685,1641.88,861.12195,493.76114,1188.4148,2461.0078,2245.233,5126.7695,1894.0508,2903.0837,773.8,1786.727,1207.7972]" (chisqrVals "THISISNORMALTEXT")
    check "chisqrVals3" "[304.35477,1074.429,2130.72,4389.452,481.68253,903.0427,198.86958,968.2157,1090.9866,3630.6326,3094.483,1148.6993,752.69904,964.89685,1641.88,861.12195,493.76114,1188.4148,2461.0078,2245.233,5126.7695,1894.0508,2903.0837,773.8,1786.727,1207.7972]" (chisqrVals "THIS IS NORMAL TEXT")
    -- Exercise 2: page construction, the ZPM laws for characters, strings,
    -- integer lists and pages, and redaction.
    check "pageOf" (show $ Page ["This is", "a test."]) (pageOf "This is\na test.")
    checkAll "law1_char" [' '..'~'] show (\c -> c <+> zero)
    checkAll "law2_char" [' '..'~'] show (\c -> zero <+> c)
    checkAll "law3_char" [' '..'~'] show (\c -> c <+> c <-> c)
    checkAll "law1_string" ["hello", "world", "page", ""] show (\s -> s <+> zero)
    checkAll "law2_string" ["hello", "world", "page", ""] show (\s -> zero <+> s)
    checkAll "law3_string" ["hello", "world", "page", ""] show (\s -> s <+> s <-> s)
    -- Combining two lists must give a list as long as the longer operand.
    checkAll "law3_int_list" [([],[1]),([1,2],[1::Int])]
        (\(x, y) -> show $ max (length x) (length y))
        (\(x, y) -> length $ x <+> y)
    check "law1_page" (show testPage) (testPage <+> zero)
    check "law2_page" (show testPage) (zero <+> testPage)
    check "law3_page" (show testPage) (testPage <+> testPage <-> testPage)
    check "redact1"
      (show $ Page $ "Dear XXXX," : tail (case testPage of Page p -> p))
      (redact "Page" testPage)
    check "redact2" (show $ pageOf "XXXXha") (redact "haha" $ pageOf "hahaha")
    check "redact3" (show $ pageOf "XXXXXXXX") (redact "haha" $ pageOf "hahahaha")
    check "redact4"
      (show $ Page $ take 5 (case testPage of Page p -> p) ++ ["XXXXXXXXX,","X."])
      (redact "Sincerely" testPage)

-- | Run a single named test case: print the test name, then @OK@ if the
-- 'show'n form of @c@ equals the expected string @e@, or an error line
-- quoting both the expected and the actual text otherwise.
check :: Show a => String -> String -> a -> IO ()
check name e c = putStr ("*** " ++ name ++ ": ") >> putStrLn verdict
  where
    actual = show c
    verdict
      | actual == e = "OK"
      | otherwise = "ERROR; expected '" ++ e ++ "', but found '" ++ actual ++ "'"

-- | Run a named test case over several inputs: for every @x@ in @xs@ the
-- 'show'n form of @c x@ must equal the expected string @e x@. Prints the
-- test name followed by @OK@ if all inputs pass, otherwise an error line
-- for the first failing input.
checkAll :: Show b => String -> [a] -> (a -> String) -> (a -> b) -> IO ()
checkAll name xs e c = do
    putStr ("*** " ++ name ++ ": ")
    -- Matching on the list of failures avoids a partial 'head'.
    case filter (\x -> show (c x) /= e x) xs of
        [] -> putStrLn "OK"
        x : _ -> putStrLn ("ERROR; expected '" ++ e x ++ "', but found '" ++ show (c x) ++ "'")