{-# LANGUAGE ScopedTypeVariables #-}

import Analysis.Confluence.CheckingConfluence (
  analyzeConfluence,
  analyzeConfluenceSeq,
  defaultStrategy,
 )
import Analysis.Confluence.Confluence (
  Answer (..),
  CRMethod (..),
  CRResult (..),
  NCRMethod (..),
  crresultToAnswer,
  prettyAnswer,
  prettyCRResult,
  showCRresultMethod,
 )
import qualified Control.Exception as X
import Control.Monad (
  replicateM,
  when,
 )
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.LCTRS (
  LCTRS,
  checkLhsRules,
  emptyLCTRS,
  prettyLCTRSDefault,
 )
import Data.LCTRS.Transformation (moveValuesToGuard, unifyRules)
import Data.LCTRS.VIdentifier (VId)
import Data.List (sort)
import Data.Maybe (
  fromMaybe,
 )
import Data.Monad
import Data.Time (defaultTimeLocale, formatTime, getCurrentTime)
import Data.Version (showVersion)
import GHC.Conc (setNumCapabilities)
import Options.Applicative hiding (
  Failure,
  Success,
 )
import qualified Parser.Ari as Ari
import Parser.Strategy (crStrategyExplanation, parseCRStrategy)
import Paths_crest (version)
import SimpleSMT (
  forceStop,
 )
import System.Clock (
  Clock (Monotonic),
  diffTimeSpec,
  getTime,
  toNanoSecs,
 )
import System.FilePath (replaceBaseName, takeBaseName)
import System.Process (
  readCreateProcess,
  shell,
 )
import qualified System.Timeout as Timeout
import Text.Printf (printf)
import Type.SortLCTRS (DefaultFId, DefaultVId)
import Type.TypeChecking (typeCheck)
import Type.TypeInference (deriveTypes)
import Utils (urlEncode)

-- | Command-line options of the benchmark driver, filled in by 'arguments'.
data Args = Args
  { benchpath :: FilePath
  -- ^ Directory that is searched (recursively, via @find@) for @.ari@ files.
  , timeoutI :: Maybe Int
  -- ^ Optional time limit in seconds for each benchmark file; it is also
  -- handed to the solvers as their timeout, which defaults to 10 when absent.
  , threads :: Maybe Int
  -- ^ Optional number of capabilities passed to 'setNumCapabilities'.
  , z3 :: Bool
  -- ^ Select Z3 (also used when no solver flag is given).
  , cvc5 :: Bool
  -- ^ Select CVC5 (ignored if @--z3@ is also given).
  , yices :: Bool
  -- ^ Select Yices (ignored if @--z3@ or @--cvc5@ is also given).
  , sequential :: Bool
  -- ^ Use the sequential confluence analysis instead of the concurrent one.
  , strategy :: Maybe String
  -- ^ Optional analysis strategy, parsed with 'parseCRStrategy'.
  , html :: Bool
  -- ^ Write a @.proof@ file per benchmark and an HTML overview page.
  }

-- | Benchmark driver entry point.
--
-- Collects every @.ari@ file below the given directory (sorted by path),
-- analyses them one after another with 'analyze', prints aggregate
-- statistics to stdout and, with @--html@, writes an HTML overview page.
main :: IO ()
main = do
  args <- execParser opts
  maybe (return ()) setNumCapabilities $ threads args
  let benchs = benchpath args
  -- The directory is single-quoted so that paths containing spaces or shell
  -- metacharacters are searched verbatim. The final order comes from the
  -- Haskell-side 'sort'; the piped `sort` keeps the pipeline's exit status
  -- that of `sort`, so a failing `find` does not make this call throw.
  fps <-
    sort
      . lines
      <$> readCreateProcess
        (shell $ "find " ++ shellQuote benchs ++ " -type f -name '*.ari' | sort")
        mempty
  let extsolver
        | z3 args = Z3
        | cvc5 args = CVC5
        | yices args = Yices
        | otherwise = Z3
  results <-
    mapM (\(i, fp) -> analyze extsolver args i fp) $
      zip [0 ..] fps
  let count p = length $ filter p results
      totalMs p = sum . map extractTime $ filter p results
      -- Guard against an empty benchmark directory (0 / 0 would print NaN).
      avgMs
        | null results = 0
        | otherwise = totalMs (const True) / fromIntegral (length results)
  -- Print the summary before writing the HTML page, so that a failure while
  -- writing the page cannot swallow the statistics of a long benchmark run.
  printf
    "Benchmark Files in \"%s\":\n  Files: %d\n  Yes: %d\n  No: %d\n  Maybe: %d\n  Timeout: %d\n  Error:  %d\n  AVG Time: %6.2f s\n  Total Time: %6.2f s\n  Total Time (Yes): %6.2f s\n  Total Time (No): %6.2f s\n  Total Time (Maybe): %6.2f s\n  Orthogonal: %d\n  Weakly Orthogonal: %d\n  Strongly Closed: %d\n  Parallel Closed: %d\n  Almost Parallel Closed: %d\n  Development Closed: %d\n  Almost Development Closed: %d\n  Two Different Normal Forms Found: %d\n  Newman's Lemma: %d\n  Toyama 81: %d\n"
    (benchpath args)
    (length fps)
    (count (onlyResults Yes))
    (count (onlyResults No))
    (count (onlyResults Maybe))
    (count isTimeout)
    (count isError)
    (milliToSec avgMs)
    (milliToSec $ totalMs (const True))
    (milliToSec $ totalMs isYes)
    (milliToSec $ totalMs isNo)
    (milliToSec $ totalMs isMaybe)
    (count $ onlyCRMethod Orthogonality)
    (count $ onlyCRMethod WeakOrthogonality)
    (count $ onlyCRMethod StrongClosedness)
    (count $ onlyCRMethod ParallelClosedness)
    (count $ onlyCRMethod AlmostParallelClosedness)
    (count $ onlyCRMethod DevelopmentClosedness)
    (count $ onlyCRMethod AlmostDevelopmentClosedness)
    (count $ onlyNCRMethod TwoDifferentNFs)
    (count $ onlyCRMethod NewmansLemma)
    (count $ onlyCRMethod Toyama81)
  when (html args) $ produceHTML $ zip fps results
 where
  opts =
    info
      (helper <*> arguments)
      ( fullDesc
          <> progDesc
            ( "(Non-)Confluence analysis benchmarking executable of the Constrained REwriting Sortware Tool version "
                <> getVersion
                <> ". Searches for files with \".ari\" file extension in the given directory."
            )
          <> header ("Constrained REwriting Sortware Tool version " <> getVersion <> " for analyzing LCTRSs.")
      )

  -- Benchmark times are recorded in milliseconds; the summary reports seconds.
  milliToSec :: Time -> Double
  milliToSec = (/ 1000)

  -- Wrap a string in single quotes for a POSIX shell; every embedded single
  -- quote is emitted as '\'' (close quote, escaped quote, reopen quote).
  shellQuote s = "'" ++ concatMap quoteChar s ++ "'"
  quoteChar '\'' = "'\\''"
  quoteChar c = [c]

-- | Version of the crest package this executable was built from, rendered
-- as a dotted version string.
getVersion :: String
getVersion = showVersion Paths_crest.version

-- | Wall-clock time in milliseconds.
type Time = Double
-- | Outcome of analysing a single benchmark file.
data BenchResult
  = Result (Answer, CRResult, Time, String)
  -- ^ The analysis finished: the answer, the full confluence result, the
  -- elapsed time and the pretty-printed LCTRS that was analysed.
  | Timeout Time
  -- ^ The per-benchmark time limit expired; carries the time measured then.
  | Error Time (Either X.SomeException String)
  -- ^ The benchmark failed, either with an exception ('Left') or because the
  -- input was rejected ('Right' carries the message).

-- | Compact, column-aligned rendering used for the per-file progress line:
-- a padded outcome label followed by the time in milliseconds.
instance Show BenchResult where
  show result = case result of
    Timeout t -> printf "Timeout (%9.2f ms)" t
    Error t _ -> printf "  Error (%9.2f ms)" t
    Result (answer, _, t, _) ->
      printf (leadingSpaces answer ++ "%s (%9.2f ms)") (show $ prettyAnswer answer) t
   where
    -- Left padding per answer, so the labels line up with "Timeout".
    leadingSpaces Yes = "    "
    leadingSpaces No = "     "
    leadingSpaces Maybe = "  "

-- show (Result (a, _, t, _)) = printf "%s (%6.2f ms)" (show $ prettyAnswer a) t

-- | Does this result prove confluence with exactly the given method?
-- Timeouts, errors and non-confluent results never match.
onlyCRMethod :: CRMethod -> BenchResult -> Bool
onlyCRMethod method result = case result of
  Result (_, Confluent (used, _), _, _) -> used == method
  _ -> False

-- | Does this result prove non-confluence with exactly the given method?
-- Timeouts, errors and other results never match.
onlyNCRMethod :: NCRMethod -> BenchResult -> Bool
onlyNCRMethod method result = case result of
  Result (_, NonConfluent (used, _), _, _) -> used == method
  _ -> False

-- | Is this a finished analysis whose answer equals @goal@?
-- Timeouts and errors never match.
onlyResults :: Answer -> BenchResult -> Bool
onlyResults goal result = case result of
  Result (answer, _, _, _) -> answer == goal
  _ -> False

-- | Did the benchmark hit its time limit?
isTimeout :: BenchResult -> Bool
isTimeout result = case result of
  Timeout {} -> True
  _ -> False

-- | Did the benchmark fail (exception or rejected input)?
isError :: BenchResult -> Bool
isError result = case result of
  Error {} -> True
  _ -> False

-- | Did the analysis prove confluence, with any method?
isYes :: BenchResult -> Bool
isYes result = case result of
  Result (_, Confluent (_, _), _, _) -> True
  _ -> False

-- | Did the analysis prove non-confluence, with any method?
isNo :: BenchResult -> Bool
isNo result = case result of
  Result (_, NonConfluent (_, _), _, _) -> True
  _ -> False

-- | Did the analysis finish without deciding confluence?
isMaybe :: BenchResult -> Bool
isMaybe result = case result of
  Result (_, MaybeConfluent, _, _) -> True
  _ -> False

-- | Time recorded for a benchmark, whatever its outcome.
extractTime :: BenchResult -> Time
extractTime result = case result of
  Result (_, _, elapsed, _) -> elapsed
  Timeout elapsed -> elapsed
  Error elapsed _ -> elapsed

-- | Analyse one benchmark file and report the outcome.
--
-- Starts one solver per strategy entry (released again via 'X.bracket'),
-- parses the @.ari@ file, preprocesses the LCTRS and, if it passes the type
-- and left-hand-side checks, runs the sequential or concurrent confluence
-- analysis. Everything runs under the optional @--timeout@ (in seconds).
--
-- A one-line summary prefixed with the index @i@ is printed to stdout; with
-- @--html@ the details are additionally written to @<file>.proof@.
--
-- An exception raised during the analysis is recorded as 'Error' together
-- with the time spent until it occurred. A user interrupt (Ctrl-C) is
-- re-thrown instead, so the whole benchmark run can be aborted.
analyze :: ExtSolver -> Args -> Int -> FilePath -> IO BenchResult
analyze extsolver args i file = do
  startTime <- getTime Monotonic
  let
    -- Milliseconds elapsed since 'startTime'; usable both in IO and inside
    -- the analysis monad.
    elapsedMs :: MonadIO m => m Time
    elapsedMs = do
      endTime <- liftIO $ getTime Monotonic
      let elapsed = toNanoSecs $ diffTimeSpec endTime startTime
      return (fromIntegral elapsed * 1e-6)
    to = timeoutI args
    timeoutId Nothing = fmap Just
    timeoutId (Just t) = timeout t
    solverTO = fromMaybe 10 $ timeoutI args -- impose timeout for Yices
    inputStrategy = parseCRStrategy <$> strategy args
    startSolver = chooseSolver solverTO False extsolver
    -- One solver per strategy entry (the default strategy if none is given).
    startSolvers = case inputStrategy of
      Nothing ->
        replicateM
          (length $ defaultStrategy (emptyLCTRS :: LCTRS (FId DefaultFId) (VId DefaultVId)) [])
          startSolver
      Just is -> mapM (const startSolver) is
    recordFailure (e :: X.SomeException)
      | X.fromException e == Just X.UserInterrupt = X.throwIO e
      | otherwise = do
          t <- elapsedMs
          return $ Just $ Error t (Left e)
  ret <- X.handle recordFailure $ timeoutId to $
    X.bracket startSolvers (mapM_ forceStop) $ \solvers -> execFreshM 0 $ do
      input <- Ari.fromFile' file
      let lctrs = deriveTypes input
      lctrs' <- moveValuesToGuard lctrs >>= unifyRules
      if typeCheck (Ari.getFSorts input) (Ari.getThSorts input) lctrs' && checkLhsRules lctrs'
        then do
          result <-
            if sequential args
              then analyzeConfluenceSeq extsolver solvers inputStrategy False solverTO lctrs'
              else analyzeConfluence extsolver solvers inputStrategy False solverTO lctrs'
          time <- elapsedMs
          return $ Result (crresultToAnswer result, result, time, show $ prettyLCTRSDefault lctrs')
        else do
          time <- elapsedMs
          return $ Error time $ Right "Does not fullfill LCTRS properties."
  -- 'Nothing' from the timeout wrapper means the time limit expired.
  result <- maybe (Timeout <$> elapsedMs) return ret
  when (html args) $ writeProof (file ++ ".proof") result
  putStrLn $ show i ++ ": " ++ show result ++ "  " ++ replicate (length $ show i) ' ' ++ file
  return result
 where
  -- Write the human-readable proof / failure report for one benchmark.
  writeProof :: FilePath -> BenchResult -> IO ()
  writeProof filename result = writeFile filename $ case result of
    Timeout t -> printf "Timeout after %6.2f ms.\n" t
    Error t (Left e) -> printf "Error after %6.2f ms.\nException: %s" t (show e)
    Error t (Right s) -> printf "Error after %6.2f ms.\nOther Error: %s" t s
    Result (res, proof, t, lctrs) ->
      show (prettyAnswer res)
        ++ "\n"
        ++ lctrs
        ++ "\n"
        ++ show (prettyCRResult proof)
        ++ "\n\n"
        ++ printf "Elapsed Time: %6.2f ms" t

-- return $ Result (crresultToAnswer proof, proof, time)

-- | Write an HTML table summarising all benchmark results to
-- @crest-cr-benchmark_<timestamp>.html@ in the current working directory.
--
-- Each row links to the problem file and to @<problem>.proof@, and shows the
-- colour-coded answer, the deciding method and the time in milliseconds.
-- File names and method names are HTML-escaped before they are embedded, so
-- characters such as @&@, @<@ or @"@ cannot break the generated markup.
produceHTML :: [(String, BenchResult)] -> IO ()
produceHTML results = do
  -- NOTE(review): the timestamp contains an em dash and colons, which are not
  -- valid in file names on every platform/locale.
  timestamp <- formatTime defaultTimeLocale "%Y-%m-%d—%H:%M:%S" <$> getCurrentTime
  let htmlOut =
        wrapTR (wrapTH "Problem" ++ wrapTH "Result" ++ wrapTH "Method" ++ wrapTH "Time (in ms)")
          ++ concatMap go results
  let pureResults = map snd results
  let yeses = length $ filter (onlyResults Yes) pureResults
  let nos = length $ filter (onlyResults No) pureResults
  let maybes = length $ filter (onlyResults Maybe) pureResults
  let timeouts = length $ filter isTimeout pureResults
  let errors = length $ filter isError pureResults
  let overallTime = sum $ map extractTime pureResults
  let summary =
        wrapTR
          ( wrapTH ("Summary of the " ++ show (length pureResults) ++ " problems: ")
              ++ wrapTHClass "summaryans" ("#YES: " ++ show yeses)
              ++ wrapTH ""
              ++ wrapTHClass "time" (printf "%6.2f" overallTime)
          )
          ++ wrapTR (wrapTH "" ++ wrapTHClass "summaryans" ("#NO: " ++ show nos) ++ wrapTH "" ++ wrapTH "")
          ++ wrapTR (wrapTH "" ++ wrapTHClass "summaryans" ("#MAYBE: " ++ show maybes) ++ wrapTH "" ++ wrapTH "")
          ++ wrapTR
            (wrapTH "" ++ wrapTHClass "summaryans" ("#Timeout: " ++ show timeouts) ++ wrapTH "" ++ wrapTH "")
          ++ wrapTR (wrapTH "" ++ wrapTHClass "summaryans" ("#Error: " ++ show errors) ++ wrapTH "" ++ wrapTH "")
  writeFile ("crest-cr-benchmark_" ++ timestamp ++ ".html") $ prefix ++ htmlOut ++ summary ++ suffix
 where
  -- One table row per benchmark.
  go (problemPath, r@(Timeout t)) =
    wrapTR $
      wrapTD (problem2HTML problemPath)
        ++ wrapTDAnswer r (problem2Proof problemPath "Timeout")
        ++ wrapTDClass "method" "---"
        ++ wrapTDClass "time" (printf "%6.2f" t)
  go (problemPath, r@(Error t _)) =
    wrapTR $
      wrapTD (problem2HTML problemPath)
        ++ wrapTDAnswer r (problem2Proof problemPath "Error")
        ++ wrapTDClass "method" "---"
        ++ wrapTDClass "time" (printf "%6.2f" t)
  go (problemPath, r@(Result (ans, res, t, _))) =
    wrapTR $
      wrapTD (problem2HTML problemPath)
        ++ wrapTDAnswer r (problem2Proof problemPath (show $ prettyAnswer ans))
        ++ wrapTDClass "method" (escapeHtml $ showCRresultMethod res)
        ++ wrapTDClass "time" (printf "%6.2f" t)

  -- Link target for a problem: its path with the base name URL-encoded.
  problemHref path = replaceBaseName path (urlEncode $ takeBaseName path)

  problem2HTML path =
    printf
      "<a href=\"%s\">%s</a>"
      (escapeHtml $ problemHref path)
      (escapeHtml $ takeBaseName path)
  problem2Proof path res =
    printf "<a href=\"%s.proof\">%s</a>" (escapeHtml $ problemHref path) (escapeHtml res)

  -- Escape the characters that are significant in HTML text and in
  -- double-quoted attribute values.
  escapeHtml :: String -> String
  escapeHtml = concatMap escapeChar
  escapeChar '&' = "&amp;"
  escapeChar '<' = "&lt;"
  escapeChar '>' = "&gt;"
  escapeChar '"' = "&quot;"
  escapeChar c = [c]

  prefix =
    "<!DOCTYPE html><html><head><style>#results {width:100%}\n#results td {border: 1px solid black;}\n#results .summaryans {text-align: left;}\n#results .time {text-align: right;}\n#results .result {text-align: center;}\n#results .method {text-align: center;}</style></head><body><table id=\"results\">"
  suffix = "</table></body></html>"

  wrapTR s = "<tr>" ++ s ++ "</tr>"
  wrapTH s = "<th>" ++ s ++ "</th>"
  wrapTD s = "<td>" ++ s ++ "</td>"
  wrapTDClass c s = "<td class=\"" ++ c ++ "\">" ++ s ++ "</td>"
  wrapTHClass c s = "<th class=\"" ++ c ++ "\">" ++ s ++ "</th>"

  -- Result cell, coloured by outcome (gray for timeouts and errors).
  wrapTDAnswer (Timeout _) s = "<td class=\"result\" style=\"background-color:Gray;\">" ++ s ++ "</td>"
  wrapTDAnswer (Error _ _) s = "<td class=\"result\" style=\"background-color:Gray;\">" ++ s ++ "</td>"
  wrapTDAnswer (Result (_, res, _, _)) s
    | crresultToAnswer res == No =
        "<td class=\"result\" style=\"background-color:Tomato;\">" ++ s ++ "</td>"
    | crresultToAnswer res == Yes =
        "<td class=\"result\" style=\"background-color:MediumSeaGreen;\">" ++ s ++ "</td>"
    | crresultToAnswer res == Maybe =
        "<td class=\"result\" style=\"background-color:Orange;\">" ++ s ++ "</td>"
    | otherwise = error "Cannot wrap answer into class for HTML representation, this should not happen."

-- | Command-line parser for the benchmark driver.
--
-- The help texts state the timeout unit (seconds, as used by 'timeout') and
-- the actual name of the HTML report written by the driver.
arguments :: Parser Args
arguments =
  Args
    <$> argument
      str
      ( metavar "BENCHMARKPATH"
          <> help "Path to directory containing benchmarks"
      )
    <*> optional
      ( option
          auto
          ( metavar "TIMEOUT"
              <> long "timeout"
              <> short 't'
              <> help
                "Timeout of the analysis of each benchmark in seconds"
          )
      )
    <*> optional
      ( option
          auto
          ( metavar "THREADS"
              <> long "threads"
              <> short 'j'
              <> help
                "Number of threads for the analysis"
          )
      )
    <*> switch (long "z3" <> help "Use Z3 (default)")
    <*> switch (long "cvc5" <> help "Use CVC5")
    <*> switch (long "yices" <> help "Use Yices")
    <*> switch
      ( long "seq"
          <> help
            "Uses the sequential mode of crest without any concurrency."
      )
    <*> optional
      ( option
          str
          ( metavar "STRING"
              <> long "strategy"
              <> short 's'
              <> help crStrategyExplanation
          )
      )
    <*> switch (long "html" <> help "Print proofs to files and produce a \"crest-cr-benchmark_<timestamp>.html\".")

-- | Run an action under a wall-clock limit given in whole seconds.
-- This is a thin wrapper around 'Timeout.timeout', which counts microseconds.
-- Returns 'Nothing' if the limit expired before the action finished.
timeout :: Int -> IO a -> IO (Maybe a)
timeout seconds = Timeout.timeout microseconds
 where
  microseconds = seconds * 1000000
