{-# LANGUAGE OverloadedStrings #-}

import Analysis.Confluence.Confluence (prettyAnswer)
import Analysis.Termination.CheckingTermination (analyzeTermination, defaultStrategy)
import Analysis.Termination.DependencyGraph (approximateDPGraph, stronglyConnComps)
import Analysis.Termination.Termination (snresultToAnswer)
import Control.Concurrent (setNumCapabilities)
import qualified Control.Exception as X
import Control.Monad (
  replicateM,
  unless,
  when,
 )
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.LCTRS (
  LCTRS,
  checkLhsRules,
  definesToSExpr,
  emptyLCTRS,
  prettyLCTRSDefault,
 )
import Data.LCTRS.VIdentifier (VId)
import Data.Maybe (fromMaybe)
import Data.Monad (
  Error,
  ExtSolver (CVC5, Yices, Z3),
  FreshM,
  chooseSolver,
  evalComp,
  execFreshM,
 )
import Data.Version (showVersion)
import Options.Applicative hiding (
  Failure,
  Success,
 )
import qualified Parser.Ari as Ari
import Parser.Strategy (parseSNStrategy, snStrategyExplanation)
import Paths_crest (version)
import Prettyprinter (Pretty (pretty), vsep)
import SimpleSMT (forceStop)
import System.Clock (
  Clock (Monotonic),
  diffTimeSpec,
  getTime,
  toNanoSecs,
 )
import qualified System.Timeout as Timeout
import Text.Printf (printf)
import Type.SortLCTRS (DefaultFId, DefaultVId)
import Type.TypeChecking (typeCheck)
import Type.TypeInference (deriveTypes)

-- | Command-line options of the termination analysis executable, as produced
-- by the @arguments@ parser.
data Args = Args
  { filepath :: FilePath
  -- ^ Path to the input file that is parsed into an LCTRS.
  , timeoutI :: Maybe Int
  -- ^ Optional timeout of the analysis in seconds. The same number is also
  -- handed to the SMT solver as its timeout (10 is used when absent).
  , threads :: Maybe Int
  -- ^ Optional number of capabilities passed to @setNumCapabilities@.
  , sccs :: Bool
  -- ^ Print the dependency graph and its strongly connected components
  -- instead of running the termination analysis.
  , verbose :: Bool
  -- ^ Print progress banners for the individual processing steps.
  , debugFlag :: Bool
  -- ^ Debug mode; forwarded to the solver setup and to the analysis.
  , minimalOut :: Bool
  -- ^ Suppress the printed LCTRS and proof; only answer and time are shown.
  , z3 :: Bool
  -- ^ Select Z3 as SMT solver (also the default when no solver is selected).
  , cvc5 :: Bool
  -- ^ Select CVC5 as SMT solver (only considered when 'z3' is not set).
  , yices :: Bool
  -- ^ Select Yices as SMT solver (only considered when neither 'z3' nor
  -- 'cvc5' is set).
  , -- , sequential :: Bool
    strategy :: Maybe String
  -- ^ Optional, still unparsed termination strategy given on the command line.
  }

-- | Entry point of the termination analysis executable.
--
-- After parsing the command line it runs in one of two modes:
--
-- * with @--sccs@ it prints the LCTRS, the approximated dependency graph and
--   the strongly connected components of that graph;
-- * otherwise it runs the termination analysis (optionally bounded by the
--   @--timeout@ value) and prints the answer, the LCTRS and proof unless
--   @--result@ is given, and the elapsed wall-clock time.
--
-- In both modes every started SMT solver is stopped with 'forceStop', also
-- when the computation is aborted by an exception or by the timeout.
main :: IO ()
main = do
  args <- execParser opts
  mapM_ setNumCapabilities $ threads args
  let to = timeoutI args
      solverTO = fromMaybe 10 $ timeoutI args -- impose timeout for Yices
      -- The first selected solver wins; Z3 is the default.
      extsolver
        | z3 args = Z3
        | cvc5 args = CVC5
        | yices args = Yices
        | otherwise = Z3
      startSolver = chooseSolver solverTO (debugFlag args) extsolver
      inputStrategy = parseSNStrategy <$> strategy args
  startTime <- getTime Monotonic
  if sccs args
    then do
      -- Guard the solver with bracket, exactly like the analysis branch
      -- below, so that the external process is never left running.
      computation <-
        X.bracket startSolver forceStop $ \solver -> execFreshM 0 $ do
          lctrs <- parseAndCheckLCTRS args
          evalComp solver (definesToSExpr lctrs) (debugFlag args) $ do
            liftIO $ print $ prettyLCTRSDefault lctrs
            dpgraph <- approximateDPGraph lctrs
            return (dpgraph, stronglyConnComps dpgraph)
      case handleErrorGraph computation of
        Nothing -> error "RunTermination.hs: Computation of DP graph failed."
        Just (dpgraph, sts) -> do
          print $ pretty dpgraph
          putStrLn "The following SCCs are computed:"
          print $ vsep $ map pretty sts
    else do
      res <- timeoutId to $ do
        X.bracket
          -- Start one solver per entry of the strategy that will be used.
          ( case inputStrategy of
              Nothing ->
                replicateM
                  (length $ defaultStrategy (emptyLCTRS :: LCTRS (FId DefaultFId) (VId DefaultVId)))
                  startSolver
              Just is -> mapM (const startSolver) is
          )
          (mapM_ forceStop)
          $ \solvers -> execFreshM 0 $ do
            lctrs <- parseAndCheckLCTRS args
            when (verbose args) $ do
              liftIO $ putStrLn "##################################################"
              liftIO $ putStrLn "# Check the given LCTRS for termination..."
              liftIO $ putStrLn "##################################################"
            result <-
              analyzeTermination
                solvers
                inputStrategy
                (debugFlag args)
                lctrs
            return (lctrs, result)
      case res of
        -- No trailing "\n" in the format: putStrLn already ends the line.
        Nothing -> reportElapsed "Timeout after %6.2f ms." startTime
        Just (lctrs, proof) -> do
          print $ prettyAnswer $ snresultToAnswer proof
          unless (minimalOut args) $ do
            -- putStrLn $ toAriString Nothing lctrs
            print $ prettyLCTRSDefault lctrs
            print . pretty . snd $ proof
          reportElapsed "Elapsed Time: %6.2f ms" startTime
 where
  -- Without a user-supplied timeout the action runs unbounded.
  timeoutId Nothing = fmap Just
  timeoutId (Just t) = timeout t

  -- Print the wall-clock time in milliseconds that has passed since the
  -- given start time, using a printf format with a single floating-point hole.
  reportElapsed fmt startTime = do
    endTime <- getTime Monotonic
    let elapsed = toNanoSecs $ diffTimeSpec endTime startTime
    putStrLn $ printf fmt (fromIntegral elapsed * 1e-6 :: Double)

  opts =
    info
      (helper <*> arguments)
      ( fullDesc
          <> progDesc
            ( "(Non-)Termination analysis executable of the Constrained REwriting Software Tool version "
                <> getVersion
                <> "."
            )
          <> header ("Constrained REwriting Software Tool version " <> getVersion <> " for analyzing LCTRSs.")
      )

-- | The @version@ value exported by "Paths_crest", rendered as a string with
-- 'showVersion'. Used in the program description and header of the help text.
getVersion :: String
getVersion = showVersion version

-- | Read the input file named in the arguments and turn it into a checked
-- LCTRS.
--
-- The steps are: parse the file with @Ari.fromFile'@, derive the sorts with
-- 'deriveTypes', type check the result against the sorts declared in the
-- input, and check the left-hand sides of the rules with 'checkLhsRules'.
-- In verbose mode a banner is printed before each step and @Success@ after
-- the first three.
--
-- A failed type check or left-hand-side check aborts via 'error'.
parseAndCheckLCTRS :: Args -> FreshM (LCTRS (FId DefaultFId) (VId DefaultVId))
parseAndCheckLCTRS args = do
  banner "# Parsing the given LCTRS..."
  input <- Ari.fromFile' (filepath args)
  success
  banner "# Infer sorts of variables and function symbols..."
  let lctrs = deriveTypes input
  success
  banner "# Type checking rules and terms..."
  if typeCheck (Ari.getFSorts input) (Ari.getThSorts input) lctrs
    then success
    else error "Type checking failed..."
  banner "# Checking conditions for a valid LCTRS..."
  unless (checkLhsRules lctrs) $
    error
      "Left-hand sides of rules must have a term symbol at root position."
  return lctrs
 where
  -- Horizontal rule framing every banner.
  ruler :: String
  ruler = "##################################################"

  -- Print a framed step title, but only in verbose mode.
  banner :: String -> FreshM ()
  banner title =
    when (verbose args) . liftIO $ do
      putStrLn ruler
      putStrLn title
      putStrLn ruler

  -- Report that the preceding step succeeded, but only in verbose mode.
  success :: FreshM ()
  success = when (verbose args) . liftIO $ putStrLn "Success"

-- | Command-line parser for 'Args'.
--
-- The individual parsers are combined in the order of the fields of 'Args':
-- the positional file path, the optional timeout and thread count, the
-- boolean switches and finally the optional strategy string. (A switch for a
-- sequential mode is currently disabled, matching the disabled record field.)
arguments :: Parser Args
arguments =
  Args
    <$> sourcePath
    <*> optional timeLimit
    <*> optional threadCount
    <*> flagOpt (long "sccs") "Print the strongly connected components of the Dependency Graph"
    <*> flagOpt (long "verb" <> short 'v') "Get more verbose output"
    <*> flagOpt (long "debug") "Debug mode which shows also internal errors"
    <*> flagOpt (long "result") "Suppress any output except result and time"
    <*> flagOpt (long "z3") "Use Z3 as SMT solver (default)"
    <*> flagOpt (long "cvc5") "Use CVC5 as SMT solver"
    <*> flagOpt (long "yices") "Use Yices as SMT solver"
    <*> optional strategyText
 where
  -- Positional argument naming the input file.
  sourcePath :: Parser FilePath
  sourcePath =
    argument str $
      metavar "FILEPATH" <> help "Path to the source file"

  -- @-t@ / @--timeout@, read as a number.
  timeLimit :: Parser Int
  timeLimit =
    option auto $
      metavar "TIMEOUT"
        <> long "timeout"
        <> short 't'
        <> help "Timeout of the analysis"

  -- @-j@ / @--threads@, read as a number.
  threadCount :: Parser Int
  threadCount =
    option auto $
      metavar "THREADS"
        <> long "threads"
        <> short 'j'
        <> help "Number of Threads for the analysis"

  -- @-s@ / @--strategy@, kept as the raw string.
  strategyText :: Parser String
  strategyText =
    option str $
      metavar "STRING"
        <> long "strategy"
        <> short 's'
        <> help snStrategyExplanation

  -- A boolean switch with the given name modifiers and help text.
  flagOpt :: Mod FlagFields Bool -> String -> Parser Bool
  flagOpt names description = switch (names <> help description)

-- | Run an action with a time limit given in seconds.
--
-- Yields 'Nothing' when the limit expires before the action finishes and
-- 'Just' its result otherwise. The limit is converted to microseconds for
-- 'Timeout.timeout'. A negative limit, or one so large that the conversion
-- would overflow 'Int', is treated as "no limit" instead of wrapping around
-- to an arbitrary (possibly zero) number of microseconds.
timeout :: Int -> IO a -> IO (Maybe a)
timeout sec
  | sec < 0 || sec > maxBound `div` microsPerSecond = Timeout.timeout (-1)
  | otherwise = Timeout.timeout (sec * microsPerSecond)
 where
  microsPerSecond :: Int
  microsPerSecond = 1000000

-- | Forget the error of a computation: a 'Left' becomes 'Nothing', a 'Right'
-- result is wrapped in 'Just'.
handleErrorGraph :: Either Error a -> Maybe a
handleErrorGraph = either (const Nothing) Just
