{-# LANGUAGE OverloadedStrings #-}

import Analysis.Confluence.CheckingConfluence (
  analyzeConfluence,
  analyzeConfluenceSeq,
  defaultStrategy,
 )
import Analysis.Confluence.Confluence (
  crresultToAnswer,
  prettyAnswer,
  prettyCRResult,
 )
import Analysis.Confluence.NewmansLemma (
  isLocallyConfluent,
 )
import Control.Concurrent (setNumCapabilities)
import qualified Control.Exception as X
import Control.Monad (
  replicateM,
  unless,
  when,
 )
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.LCTRS (
  LCTRS,
  checkLhsRules,
  definesToSExpr,
  emptyLCTRS,
  getRules,
  prettyARILCTRS,
  prettyLCTRSDefault,
 )
import Data.LCTRS.Transformation (moveValuesToGuard, unifyRules)
import Data.LCTRS.VIdentifier (VId)
import Data.List (sort)
import Data.Maybe (fromMaybe, isJust)
import Data.Monad
import Data.Version (showVersion)
import Options.Applicative hiding (
  Failure,
  Success,
 )
import Parser.Ari (getLogic)
import qualified Parser.Ari as Ari
import Parser.Strategy (crStrategyExplanation, parseCRStrategy)
import Paths_crest (version)
import qualified Prettyprinter as PP
import Rewriting.CriticalPair (
  computeCPs,
  prettyCriticalPair,
  prettyCriticalPairs,
 )
import Rewriting.ParallelCriticalPair (
  computeParallelCPs,
  prettyParallelCriticalPair,
  prettyParallelCriticalPairs,
 )
import SimpleSMT (forceStop)
import System.Clock (
  Clock (Monotonic),
  diffTimeSpec,
  getTime,
  toNanoSecs,
 )
import qualified System.Timeout as Timeout
import Text.Printf (printf)
import Type.SortLCTRS (DefaultFId, DefaultVId)
import Type.TypeChecking (typeCheck)
import Type.TypeInference (deriveTypes)

-- | Parsed command-line options.
--
-- The record is filled positionally by 'arguments', so the field order here
-- must stay in sync with the order of the parsers there.
data Args = Args
  { filepath :: FilePath
  -- ^ Path of the input file, read with the ARI parser.
  , timeoutI, threads :: Maybe Int
  -- ^ @--timeout@ (seconds; also the per-solver timeout) and @--threads@
  -- (number of RTS capabilities); 'Nothing' when the option is absent.
  , cps, pcps, wcr :: Bool
  -- ^ Alternative modes: print critical pairs, print parallel critical
  -- pairs, or check local confluence only.
  , verbose, debugFlag, minimalOut :: Bool
  -- ^ Output control: progress banners, debug mode forwarded to solvers and
  -- analyses, and result-plus-time-only output.
  , z3, cvc5, yices :: Bool
  -- ^ SMT solver selection; the first flag set (in this order) wins, Z3 is
  -- used when none is set.
  , sequential :: Bool
  -- ^ Use the sequential confluence analysis instead of the concurrent one.
  , strategy :: Maybe String
  -- ^ Raw @--strategy@ string, parsed later by 'parseCRStrategy'.
  , ari :: Bool
  -- ^ Only print the fully sorted LCTRS in ARI format.
  }

-- | Entry point of the crest executable.
--
-- Parses the command line, applies @--threads@ to the RTS and then runs the
-- first selected mode, checked in this order:
--
--   * @--cps@: print the constrained critical pairs,
--   * @--pcps@: print the constrained parallel critical pairs,
--   * @--wcr@: check local confluence of the transformed LCTRS,
--   * @--ari@: print the fully sorted LCTRS in ARI format,
--   * otherwise: run the (non-)confluence analysis, bounded by @--timeout@.
main :: IO ()
main = do
  args <- execParser opts
  mapM_ setNumCapabilities (threads args)
  runMode args
 where
  runMode args
    | cps args = runCriticalPairs args
    | pcps args = runParallelCriticalPairs args
    | wcr args = runLocalConfluence args
    | ari args = runAri args
    | otherwise = runConfluence args

  -- Timeout handed to every SMT solver: @--timeout@, or 10 when absent
  -- (imposed so that solvers such as Yices do not run unbounded).
  solverTimeout args = fromMaybe 10 (timeoutI args)

  -- The first of --z3, --cvc5, --yices that is set wins; Z3 otherwise.
  solverKind args
    | z3 args = Z3
    | cvc5 args = CVC5
    | yices args = Yices
    | otherwise = Z3

  -- Start one SMT solver with the configured kind, timeout and debug mode.
  startSolver args =
    chooseSolver (solverTimeout args) (debugFlag args) (solverKind args)

  -- Three-line section banner, printed only with --verb.
  banner args title =
    when (verbose args) $ liftIO $ do
      putStrLn "##################################################"
      putStrLn title
      putStrLn "##################################################"

  -- Milliseconds elapsed since the given monotonic timestamp.
  elapsedMs start = do
    end <- getTime Monotonic
    let nanos = toNanoSecs (diffTimeSpec end start)
    return (fromIntegral nanos * 1e-6 :: Double)

  -- Run an action under @--timeout@ if one was given, otherwise unbounded.
  timeoutId Nothing = fmap Just
  timeoutId (Just t) = timeout t

  -- --cps: compute and print the constrained critical pairs.
  runCriticalPairs args = do
    (res, lctrs) <- execFreshM 0 $ do
      (lctrs, _) <- parseAndCheckLCTRS args
      s <- liftIO (startSolver args)
      fresh <- freshI
      criticalPairs <-
        liftIO $
          evalStateM s (definesToSExpr lctrs) (debugFlag args) fresh $
            computeCPs lctrs
      return (criticalPairs, lctrs)
    print $ prettyLCTRSDefault lctrs
    case res of
      Left (MayError e) -> error e
      Left (SevError e) -> error e
      Right [] -> putStrLn "No constrained critical pairs"
      Right [cp] ->
        print
          ( "1 constrained critical pair found (without symmetries):"
              PP.<> PP.line
              PP.<> prettyCriticalPair cp
          )
      Right cs ->
        print
          ( PP.pretty (length cs)
              PP.<+> "constrained critical pairs found (without symmetries):"
              PP.<> PP.line
              PP.<> prettyCriticalPairs cs
          )

  -- --pcps: compute and print the constrained parallel critical pairs.
  runParallelCriticalPairs args = do
    (res, lctrs) <- execFreshM 0 $ do
      (lctrs, _) <- parseAndCheckLCTRS args
      s <- liftIO (startSolver args)
      fresh <- freshI
      parallelCriticalPairs <-
        liftIO $
          evalStateM s (definesToSExpr lctrs) (debugFlag args) fresh $
            computeParallelCPs lctrs
      return (parallelCriticalPairs, lctrs)
    print $ prettyLCTRSDefault lctrs
    case res of
      Left (MayError e) -> error e
      Left (SevError e) -> error e
      Right [] -> putStrLn "No constrained parallel critical pairs"
      Right [cp] ->
        print
          ( "1 constrained parallel critical pair found..."
              PP.<> PP.line
              PP.<> prettyParallelCriticalPair cp
          )
      Right cs ->
        print
          ( PP.pretty (length cs)
              PP.<+> "constrained parallel critical pairs found..."
              PP.<> PP.line
              PP.<> prettyParallelCriticalPairs cs
          )

  -- --wcr: check local confluence on the transformed LCTRS. It reports
  -- success only if every critical pair yields a 'Just' result.
  runLocalConfluence args = do
    (crresult, lctrs) <- execFreshM 0 $ do
      (parsed, _) <- parseAndCheckLCTRS args
      lctrs <- moveValuesToGuard parsed >>= unifyRules
      s <- liftIO (startSolver args)
      fresh <- freshI
      wcrresults <-
        liftIO $
          evalStateM s (definesToSExpr lctrs) (debugFlag args) fresh $
            mapM (isLocallyConfluent lctrs) =<< computeCPs lctrs
      return (wcrresults, lctrs)
    print $ prettyLCTRSDefault lctrs
    case crresult of
      Left (MayError e) -> error e
      Left (SevError e) -> error e
      Right crresults | all isJust crresults -> putStrLn "LCTRS is locally confluent."
      _ -> putStrLn "Cannot determine Local Confluence."

  -- --ari: print the fully sorted LCTRS in ARI format.
  runAri args = execFreshM 0 $ do
    (lctrs, _) <- parseAndCheckLCTRS args
    liftIO $ print (prettyARILCTRS lctrs)

  -- Default mode: (non-)confluence analysis under the global timeout.
  runConfluence args = do
    startTime <- getTime Monotonic
    let inputStrategy = parseCRStrategy <$> strategy args
        extsolver = solverKind args
        solverTO = solverTimeout args
        -- One solver per analysis method: as many as the default strategy
        -- has entries when --strategy is absent, otherwise one per entry
        -- of the parsed strategy.
        acquireSolvers = case inputStrategy of
          Nothing ->
            replicateM
              (length $ defaultStrategy (emptyLCTRS :: LCTRS (FId DefaultFId) (VId DefaultVId)) [])
              (startSolver args)
          Just is -> mapM (const (startSolver args)) is
    -- bracket guarantees the solvers are stopped, also on timeout.
    res <- timeoutId (timeoutI args) $
      X.bracket acquireSolvers (mapM_ forceStop) $ \solvers -> execFreshM 0 $ do
        (inputLCTRS, _) <- parseAndCheckLCTRS args
        banner args "# Move Values to Guard and unify rules ..."
        transformedLCTRS <- moveValuesToGuard inputLCTRS >>= unifyRules
        banner args "# Check the given LCTRS for confluence..."
        result <-
          if sequential args
            then
              analyzeConfluenceSeq
                extsolver
                solvers
                inputStrategy
                (debugFlag args)
                solverTO
                transformedLCTRS
            else
              analyzeConfluence
                extsolver
                solvers
                inputStrategy
                (debugFlag args)
                solverTO
                transformedLCTRS
        return ((inputLCTRS, transformedLCTRS), result)
    case res of
      Nothing -> do
        ms <- elapsedMs startTime
        -- putStrLn already terminates the line; no "\n" in the format.
        putStrLn $ printf "Timeout after %6.2f ms." ms
      Just ((inputLCTRS, transformedLCTRS), proof) -> do
        print $ prettyAnswer $ crresultToAnswer proof
        -- --result promises only the result and the time, so all LCTRS and
        -- proof output (including the untransformed input) is suppressed.
        unless (minimalOut args) $ do
          unless (sort (getRules inputLCTRS) == sort (getRules transformedLCTRS)) $ do
            putStrLn "Input LCTRS..."
            print $ prettyLCTRSDefault inputLCTRS
            putStrLn "Apply Transformations to the input LCTRS..."
          print $ prettyLCTRSDefault transformedLCTRS
          print $ prettyCRResult proof
        ms <- elapsedMs startTime
        putStrLn $ printf "Elapsed Time: %6.2f ms" ms

  opts =
    info
      (helper <*> arguments)
      ( fullDesc
          PP.<> progDesc
            ( "(Non-)Confluence analysis executable of the Constrained REwriting Sortware Tool version "
                <> getVersion
                <> "."
            )
          PP.<> header ("Constrained REwriting Sortware Tool version " <> getVersion <> " for analyzing LCTRSs.")
      )

-- | Version of the crest package, from the Cabal-generated @Paths_crest@
-- module, rendered as a dotted string for the help text.
getVersion :: String
getVersion = rendered
 where
  rendered = showVersion version

-- | Read the ARI file named by 'filepath' and turn it into a sorted,
-- validated LCTRS.
--
-- Stages: parse, infer sorts ('deriveTypes'), type check ('typeCheck'), and
-- check that every left-hand side has a term symbol at its root
-- ('checkLhsRules'). Calls 'error' if type checking or the left-hand-side
-- check fails. With @--verb@, a banner is printed before each stage and
-- @Success@ after each stage up to and including type checking.
--
-- Returns the LCTRS together with the logic declared in the input.
parseAndCheckLCTRS :: Args -> FreshM (LCTRS (FId DefaultFId) (VId DefaultVId), Logic)
parseAndCheckLCTRS args = do
  stage "# Parsing the given LCTRS..."
  input <- Ari.fromFile' (filepath args)
  say "Success"
  stage "# Infer sorts of variables and function symbols..."
  let sortedLCTRS = deriveTypes input
  say "Success"
  stage "# Type checking rules and terms..."
  unless (typeCheck (Ari.getFSorts input) (Ari.getThSorts input) sortedLCTRS) $
    error "Type checking failed..."
  say "Success"
  stage "# Checking conditions for a valid LCTRS..."
  unless (checkLhsRules sortedLCTRS) $
    error
      "Left-hand sides of rules must have a term symbol at root position."
  return (sortedLCTRS, getLogic input)
 where
  rule :: String
  rule = "##################################################"

  -- One line of progress output, emitted only with --verb.
  say msg = when (verbose args) $ liftIO $ putStrLn msg

  -- Three-line banner announcing the next stage.
  stage title = mapM_ say [rule, title, rule]

-- | Command-line parser for 'Args'.
--
-- The parsers are combined in exactly the order of the 'Args' fields (the
-- record is built positionally), which is also the order in which the
-- options are listed by @--help@.
arguments :: Parser Args
arguments =
  Args
    <$> fileArg
    <*> optional timeoutOpt
    <*> optional threadsOpt
    <*> toggle "cps" "Compute the constrained critical pairs"
    <*> toggle "pcps" "Compute the constrained parallel critical pairs"
    <*> toggle "wcr" "Show local confluence (Note: termination is not shown!)"
    <*> switch (long "verb" PP.<> short 'v' PP.<> help "Get more verbose output")
    <*> toggle "debug" "Debug mode which shows also internal errors"
    <*> toggle "result" "Suppress any output except result and time"
    <*> toggle "z3" "Use Z3 as SMT solver (default)"
    <*> toggle "cvc5" "Use CVC5 as SMT solver (experimental)"
    <*> toggle "yices" "Use Yices as SMT solver (experimental)"
    <*> toggle "seq" "Uses sequential mode of crest without concurrency"
    <*> optional strategyOpt
    <*> toggle "ari" "Prints the LCTRS fully sorted in the ARI format."
 where
  -- An on/off flag that has only a long name.
  toggle :: String -> String -> Parser Bool
  toggle name desc = switch (long name PP.<> help desc)

  fileArg :: Parser FilePath
  fileArg =
    argument str (metavar "FILEPATH" PP.<> help "Path to the source file")

  -- The value is in seconds: 'timeout' converts it to microseconds.
  timeoutOpt :: Parser Int
  timeoutOpt =
    option
      auto
      ( metavar "TIMEOUT"
          PP.<> long "timeout"
          PP.<> short 't'
          PP.<> help "Timeout of the analysis in seconds"
      )

  threadsOpt :: Parser Int
  threadsOpt =
    option
      auto
      ( metavar "THREADS"
          PP.<> long "threads"
          PP.<> short 'j'
          PP.<> help "Number of Threads for the analysis"
      )

  strategyOpt :: Parser String
  strategyOpt =
    option
      str
      ( metavar "STRING"
          PP.<> long "strategy"
          PP.<> short 's'
          PP.<> help crStrategyExplanation
      )

-- | Run an IO action with a time limit given in whole seconds.
--
-- Returns 'Nothing' if the limit expired and @'Just' r@ otherwise. As with
-- 'Timeout.timeout', a negative limit means "no limit" and a limit of 0
-- yields 'Nothing' immediately. The limit is clamped to the largest number of
-- seconds whose microsecond count still fits in an 'Int'. Without the clamp,
-- a huge limit would overflow into a wrong delay (e.g. @2^58@ seconds wraps
-- to 0 microseconds and would time out instantly).
timeout :: Int -> IO a -> IO (Maybe a)
timeout sec = Timeout.timeout (clamped * microsPerSecond)
 where
  microsPerSecond, limit, clamped :: Int
  microsPerSecond = 1000000
  limit = maxBound `div` microsPerSecond
  clamped = max (negate limit) (min limit sec)
