{- |
Module      : Utils
Description : Utility and auxiliary functions used across the library
Copyright   : (c) Jonas Schöpf, 2023
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides various utility and auxiliary functions that are used throughout the
library.
-}
module Utils (
  -- * Functions on Lists
  findM,
  findMs,
  findMsM,
  lenSensitiveZip,
  takeM,
  parMap,
  parMapM,
  mapOnMaybes,
  dropNthElem,
  dropCommonPrefix,
  shuffle,

  -- * Functions on Maps
  monadicUnionWith,
  monadicUnionsWith,

  -- * Asynchronous Functions
  asyncs,
  waitF,

  -- * Parser for Lists
  parseList,
  parseIntList,

  -- * Operations on Tuples
  fst3,
  snd3,
  thi3,

  -- * URL Encoding Function
  urlEncode,
) where

----------------------------------------------------------------------------------------------------
-- imports
----------------------------------------------------------------------------------------------------

import Control.Arrow (first)
import Control.Concurrent.Async (
  Async,
  uninterruptibleCancel,
  waitAny,
 )
import Control.Monad (forM, join, liftM2)
import Control.Monad.Combinators (
  between,
  empty,
  sepBy,
  (<|>),
 )
import qualified Control.Monad.Parallel as MPar
import qualified Control.Parallel.Strategies as Par
import Data.Array.IO (
  IOArray,
  newListArray,
  readArray,
  writeArray,
 )
import Data.Bits (shiftR, (.&.))
import Data.Char (isAlphaNum, isAscii, ord)
import qualified Data.Foldable as DF
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Void (Void)
import Data.Word (Word8)
import System.Random (randomRIO)
import Text.Megaparsec (
  Parsec,
  errorBundlePretty,
  parse,
 )
import Text.Megaparsec.Char (space1)
import qualified Text.Megaparsec.Char.Lexer as L
import UnliftIO.Async (async)

----------------------------------------------------------------------------------------------------
-- list functions
----------------------------------------------------------------------------------------------------

{- | 'findM' @p xs@ returns the leftmost element of @xs@ satisfying the monadic predicate @p@,
or 'Nothing' if there is none. The predicate is run from left to right and stops at the first
match, so its effects are only performed for elements up to and including that match.
-}
findM :: (Monad m) => (a -> m Bool) -> [a] -> m (Maybe a)
findM _ [] = return Nothing
findM predicate (x : rest) = do
  found <- predicate x
  if found then return (Just x) else findM predicate rest

{- | 'findMs' @p ms@ runs the monadic values in @ms@ from left to right and returns the first
result that satisfies the pure predicate @p@, or 'Nothing' if none does. Actions after the first
match are never run.
-}
findMs :: (Monad m) => (a -> Bool) -> [m a] -> m (Maybe a)
findMs predicate = foldr step (return Nothing)
 where
  -- Run one action; either stop with its result or continue with the rest of the fold.
  step action continue = do
    candidate <- action
    if predicate candidate then return (Just candidate) else continue

{- | 'findMsM' @p ms@ runs the monadic values in @ms@ from left to right and returns the first
result satisfying the monadic predicate @p@, or 'Nothing' if there is none. For each element the
action is run before the predicate; nothing after the first match is run.
-}
findMsM :: (Monad m) => (a -> m Bool) -> [m a] -> m (Maybe a)
findMsM _ [] = return Nothing
findMsM predicate (action : rest) = do
  val <- action
  found <- predicate val
  if found then return (Just val) else findMsM predicate rest

{- | 'lenSensitiveZip' @xs ys@ pairs up @xs@ and @ys@ like 'zip', but calls 'error' when the two
lists differ in length instead of silently truncating the longer one. Both lists are fully
traversed to compare their lengths before any pair is produced.
-}
lenSensitiveZip :: [a] -> [b] -> [(a, b)]
lenSensitiveZip xs ys =
  if length xs /= length ys
    then error "lenSensitiveZip: lengths of lists do not match."
    else zip xs ys

{- | 'takeM' @n ms@ runs the first @n@ monadic values of @ms@ from left to right and returns
their results. Like 'take', a non-positive @n@ yields @[]@ without running any action. If @ms@
has fewer than @n@ elements, all of them are run.
-}
takeM :: (Monad m) => Int -> [m a] -> m [a]
takeM n ms
  -- Previously only exactly 0 stopped the recursion, so a negative n ran every action.
  | n <= 0 = return []
  | otherwise = case ms of
      [] -> return []
      (me : rest) -> do
        val <- me
        (val :) <$> takeM (pred n) rest

{- | 'parMap' @i f xs@ maps @f@ over @xs@ and evaluates the results in parallel, in chunks of
@i@ elements, with each result evaluated to weak head normal form ('Par.rseq').
-}
parMap :: Int -> (a -> b) -> [a] -> [b]
parMap chunkSize f xs =
  map f xs `Par.using` Par.parListChunk chunkSize Par.rseq

{- | 'parMapM' @f xs@ applies the monadic function @f@ to every element of @xs@ using
'MPar.mapM', which may run the resulting actions in parallel, and returns the results in the
order of @xs@.
-}
parMapM :: (MPar.MonadParallel m) => (a -> m b) -> [a] -> m [b]
parMapM = MPar.mapM

{- | 'mapOnMaybes' @f ms@ drops every 'Nothing' in @ms@ and applies @f@ to the contents of each
'Just', keeping the original order.
-}
mapOnMaybes :: (a -> b) -> [Maybe a] -> [b]
mapOnMaybes f ms = [f e | Just e <- ms]

{- | 'dropNthElem' @n xs@ removes the element at position @n@ (counting from 0) from @xs@ and
returns the elements before it and the elements after it as a pair.

If @n@ is negative the result is @([], xs)@. If @n@ is at or past the end of @xs@, nothing is
removed and the result is @(xs, [])@. The list is walked only once and lazily, so this also works
on infinite lists.
-}
dropNthElem :: Int -> [a] -> ([a], [a])
dropNthElem n xs
  | n < 0 = ([], xs)
  | otherwise = go n xs
 where
  -- @k@ counts down to the position to drop. Running off the end yields @(xs, [])@, so no
  -- separate (strict, O(n)) length check is needed.
  go :: Int -> [b] -> ([b], [b])
  go _ [] = ([], [])
  go 0 (_ : rest) = ([], rest)
  go k (y : rest) = first (y :) (go (k - 1) rest)

{- | 'dropCommonPrefix' @xs ys@ strips the longest common prefix shared by @xs@ and @ys@ and
returns what remains of each list as a pair.
-}
dropCommonPrefix :: (Eq a) => [a] -> [a] -> ([a], [a])
dropCommonPrefix xs ys = case (xs, ys) of
  (x : xs', y : ys') | x == y -> dropCommonPrefix xs' ys'
  _ -> (xs, ys)

-- Adapted from https://wiki.haskell.org/Random_shuffle

{- | 'shuffle' @xs@ returns a random permutation of @xs@, using a Fisher–Yates shuffle over a
mutable 1-indexed 'IOArray'.
  /O(N)/
-}
shuffle :: [a] -> IO [a]
shuffle xs = do
  slots <- toArray xs
  forM [1 .. len] (step slots)
 where
  len = length xs

  toArray :: [b] -> IO (IOArray Int b)
  toArray = newListArray (1, len)

  -- Pick a random slot j in [i, len], move the value at i into slot j and emit the value that
  -- was at j. Slot i is never read again, so it does not need to be overwritten.
  step :: IOArray Int b -> Int -> IO b
  step slots i = do
    j <- randomRIO (i, len)
    current <- readArray slots i
    chosen <- readArray slots j
    writeArray slots j current
    return chosen

----------------------------------------------------------------------------------------------------
-- Map functions
----------------------------------------------------------------------------------------------------

{- | 'monadicUnionWith' @f mapA mapB@ is the union of @mapA@ and @mapB@. For keys present in both
maps, the values are combined with the monadic function @f@ (value from @mapA@ first). The
combining actions are run by 'sequence' in ascending key order, with the usual short-circuiting
of the monad @m@ (e.g. a single 'Nothing' makes the whole result 'Nothing').
-}
monadicUnionWith :: (Monad m, Ord k) => (a -> a -> m a) -> M.Map k a -> M.Map k a -> m (M.Map k a)
monadicUnionWith f mapA mapB =
  sequence (M.unionWith combine (M.map return mapA) (M.map return mapB))
 where
  combine ma mb = ma >>= \x -> mb >>= f x

{- | 'monadicUnionsWith' @f maps@ is the union of all maps in @maps@. Whenever a key occurs in
more than one map, its values are combined from left to right with the monadic function @f@.
The combining actions are run by 'sequence' in ascending key order.
-}
monadicUnionsWith :: (Monad m, Ord k) => (a -> a -> m a) -> [M.Map k a] -> m (M.Map k a)
monadicUnionsWith f maps =
  sequence (M.unionsWith bindBoth (map (M.map return) maps))
 where
  -- Run both actions, then feed their results to f.
  bindBoth ma mb = join (liftM2 f ma mb)

----------------------------------------------------------------------------------------------------
-- async functions
----------------------------------------------------------------------------------------------------

-- | 'asyncs' @ios@ starts every action in @ios@ with 'async' and returns the handles in the same order.
asyncs :: [IO a] -> IO [Async a]
asyncs ios = traverse async ios

{- | 'waitF' @p def jobs@ waits on the asynchronous @jobs@ one at a time, in the order they
finish, and returns the first result that satisfies @p@. At that point every job that is still
pending is cancelled with 'uninterruptibleCancel'. If no result satisfies @p@, or @jobs@ is
empty, @def@ is returned.

NOTE(review): if the first job to finish threw an exception, 'waitAny' re-throws it and the
remaining jobs are not cancelled here. Confirm that callers clean them up.
-}
waitF :: (a -> Bool) -> a -> S.Set (Async a) -> IO a
waitF check def = go
 where
  go pending
    | null pending = return def
    | otherwise = do
        (done, result) <- waitAny (DF.toList pending)
        let stillRunning = S.delete done pending
        if check result
          then do
            mapM_ uninterruptibleCancel stillRunning
            return result
          else go stillRunning

----------------------------------------------------------------------------------------------------
-- parser
----------------------------------------------------------------------------------------------------

{- | 'parseList' @p@ runs the element parser @p@ between square brackets. Whitespace after each
bracket is skipped; whitespace before the opening bracket is not.
-}
parseList :: Parsec Void String [a] -> Parsec Void String [a]
parseList element = between open close element
 where
  open = lexemeSymbol "["
  close = lexemeSymbol "]"
  lexemeSymbol = L.symbol spaceConsumer
  spaceConsumer = L.space space1 empty empty

{- | 'parseIntList' @s@ parses a string such as @[1, -2, 3]@ into a list of 'Int's.
Whitespace is allowed after brackets, commas, numbers and minus signs, but not before the
opening bracket. A parse failure raises 'error' with megaparsec's pretty-printed message.

NOTE(review): no end-of-input check is performed, so text after the closing bracket is
silently ignored.
-}
parseIntList :: String -> [Int]
parseIntList s = either (error . errorBundlePretty) id (parse intList "" s)
 where
  intList = parseList (integer `sepBy` symbol ",")
  integer = lexeme (negate <$> (symbol "-" >> L.decimal)) <|> lexeme L.decimal
  lexeme = L.lexeme spaceConsumer
  symbol = L.symbol spaceConsumer
  spaceConsumer = L.space space1 empty empty

----------------------------------------------------------------------------------------------------
-- tuples
----------------------------------------------------------------------------------------------------

-- | 'fst3' @tri@ returns the first component of the triple @tri@.
fst3 :: (a, b, c) -> a
fst3 = \(x, _, _) -> x

-- | 'snd3' @tri@ returns the second component of the triple @tri@.
snd3 :: (a, b, c) -> b
snd3 = \(_, y, _) -> y

-- | 'thi3' @tri@ returns the third component of the triple @tri@.
thi3 :: (a, b, c) -> c
thi3 = \(_, _, z) -> z

----------------------------------------------------------------------------------------------------
-- URL encode benchmark paths
----------------------------------------------------------------------------------------------------

-- -- | 'encodeURL' @string@ returns a string with escaped URL symbols.
-- encodeURL :: String -> String
-- encodeURL = concatMap encode
--  where
--   encode :: Char -> String
--   encode c
--     | c == ' ' = "+"
--     | isAlphaNum c || c `elem` ("-._~" :: String) = [c]
--     | otherwise = printf "%%%02X" c

{- | 'urlEncode' @string@ percent-escapes @string@ for use in a URL. ASCII letters and digits and
the characters @-@, @_@, @.@ and @~@ are kept as they are. Every other ASCII character becomes
@%XX@ (upper-case hex), so a space becomes @%20@. A non-ASCII character is first encoded as UTF-8
and each resulting byte is escaped.
Adapted from https://hackage.haskell.org/package/HTTP-4000.4.1/docs/Network-HTTP-Base.html#v:urlEncode.
-}
urlEncode :: String -> String
urlEncode = concatMap encodeOne
 where
  encodeOne c
    | (isAscii c && isAlphaNum c) || c `elem` ("-_.~" :: String) = [c]
    | isAscii c = percent (fromIntegral (fromEnum c))
    | otherwise = concatMap percent (encodeChar c)

  -- One byte rendered as '%' followed by two upper-case hex digits.
  percent :: Word8 -> String
  percent b = ['%', hexDigit (b `div` 16), hexDigit (b `mod` 16)]

  hexDigit :: Word8 -> Char
  hexDigit d
    | d <= 9 = toEnum (fromEnum '0' + fromIntegral d)
    | otherwise = toEnum (fromEnum 'A' + fromIntegral d - 10)

{- | Encode a single Haskell 'Char' as its UTF-8 byte sequence (1 to 4 bytes, chosen by the code
point's magnitude). No validation is done: surrogate code points are encoded like any other
code point below 0x10000.
Adapted from utf-8string-0.3.7.
-}
encodeChar :: Char -> [Word8]
encodeChar ch = fromIntegral <$> utf8Bytes (ord ch)
 where
  -- The six bits of @cp@ starting at bit @k@, tagged as a continuation byte (10xxxxxx).
  continuation cp k = 0x80 + ((cp `shiftR` k) .&. 0x3f)

  utf8Bytes cp
    | cp < 0x80 = [cp]
    | cp < 0x800 = [0xc0 + (cp `shiftR` 6), continuation cp 0]
    | cp < 0x10000 =
        [0xe0 + (cp `shiftR` 12), continuation cp 6, continuation cp 0]
    | otherwise =
        [ 0xf0 + (cp `shiftR` 18)
        , continuation cp 12
        , continuation cp 6
        , continuation cp 0
        ]
