module Analysis.Termination.DependencyGraph where

import Analysis.Termination.DependencyPairs (DPProblem (..), computeDPs)
import Control.Monad (filterM, zipWithM)
import Data.Array ((!))
import Data.Bifunctor (Bifunctor (bimap))
import Data.Graph (Graph, Tree (Node), Vertex, graphFromEdges, scc, vertices)
import qualified Data.IntMap as IM
import Data.LCTRS.FIdentifier
import Data.LCTRS.Guard (collapseGuardToTerm)
import Data.LCTRS.LCTRS (LCTRS, definedSyms)
import Data.LCTRS.Rule (Rule, guard, lhs, lvar, prettyRule, rhs)
import qualified Data.LCTRS.Rule as R
import Data.LCTRS.Sort (Sorted (sort))
import Data.LCTRS.Term
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier
import Data.Maybe (mapMaybe)
import Data.Monad (StateM)
import Data.SExpr (ToSExpr (..))
import qualified Data.Set as S
import Prettyprinter (Pretty (pretty), encloseSep, indent, vsep, (<+>))
import Rewriting.SMT (satSExpr)
import qualified SimpleSMT as S
import Utils (fst3)

-- | An approximated dependency graph over a collection of indexed dependency
-- pairs.
--
-- As built by 'approximateDPGraphFromDPs', the first three fields are the
-- result of 'graphFromEdges' applied to nodes of the form
-- @(dependency pair, key, successor keys)@, and 'mapping' maps every key back
-- to its dependency pair.
data DPGraph f v = DPGraph
  { graph :: Graph
  -- ^ The underlying "Data.Graph" adjacency structure.
  , nodeFromVertex :: Vertex -> (Rule f v, Int, [Int])
  -- ^ Recovers the @(dependency pair, key, successor keys)@ node of a vertex.
  , vertexFromKey :: Int -> Maybe Vertex
  -- ^ Looks up the vertex of a key; 'Nothing' if the key is not in the graph.
  , mapping :: IM.IntMap (Rule f v)
  -- ^ Index from keys to dependency pairs.
  }

-- | Renders the indexed dependency pairs as a set, followed by one line per
-- vertex of the form @key -> {successor keys}@.
instance (Pretty f, Pretty v) => Pretty (DPGraph f v) where
  pretty dg =
    vsep
      [ "DPGraph with indexed dependency pairs"
      , indent 2 (asSet numberedRules)
      , "and edges"
      , indent 2 (vsep edgeLines)
      ]
    where
      -- Brace-enclosed, comma-separated rendering shared by both sections.
      asSet docs = encloseSep "{" "}" ", " docs
      numberedRules =
        [ pretty idx <> ":" <+> prettyRule rule
        | (idx, rule) <- IM.toList (mapping dg)
        ]
      edgeLines =
        [ pretty key <+> "->" <+> asSet (map pretty succs)
        | (_, key, succs) <- map (nodeFromVertex dg) (vertices (graph dg))
        ]

-- | A strongly connected component of a 'DPGraph', represented by the
-- dependency pairs it contains. 'stronglyConnComps' only yields components
-- that contain a cycle (a singleton is kept only if it has a self-loop).
newtype SCC f v = SCC [Rule f v]

-- | Renders an SCC as the header @SCC:@ followed by its rules as an indented,
-- brace-enclosed set.
instance (Pretty f, Pretty v) => Pretty (SCC f v) where
  pretty (SCC members) = vsep ["SCC:", indent 2 memberSet]
    where
      memberSet = encloseSep "{" "}" ", " [prettyRule r | r <- members]

-- | Splits a DP problem into one subproblem per cyclic SCC of the approximated
-- dependency graph of its strict pairs. Every subproblem keeps the original
-- weak rules, whose defined symbols drive the graph approximation.
--
-- This is exactly 'transformDPProblemWithInfo' with the diagnostic graph/SCC
-- information dropped. Delegating, rather than repeating the pipeline, keeps
-- the two entry points from drifting apart.
transformDPProblem
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => DPProblem (FId f) (VId v)
  -> StateM [DPProblem (FId f) (VId v)]
transformDPProblem problem = snd <$> transformDPProblemWithInfo problem

-- | Like 'transformDPProblem', but additionally returns the approximated
-- dependency graph and its cyclic SCCs, e.g. for proof output.
--
-- The defined symbols used for the approximation are those of the weak rules.
-- Each SCC becomes the strict part of a new subproblem that keeps the original
-- weak rules.
transformDPProblemWithInfo
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => DPProblem (FId f) (VId v)
  -> StateM ((DPGraph (FId f) (VId v), [SCC (FId f) (VId v)]), [DPProblem (FId f) (VId v)])
transformDPProblemWithInfo DPProblem{strictrules = strict, weakrules = weak} = do
  let weakDefined = S.fromList (mapMaybe R.definedSym weak)
  dpGraph <- approximateDPGraphFromDPs strict weakDefined
  let components = stronglyConnComps dpGraph
      subproblems = [DPProblem members weak | SCC members <- components]
  pure ((dpGraph, components), subproblems)

-- | Computes the strongly connected components of a dependency graph that
-- contain a cycle.
--
-- Singleton components are kept only when their vertex has a self-loop, since
-- acyclic pairs are irrelevant for the DP analysis. The rules of each
-- component are listed in pre-order of the SCC tree returned by 'scc'.
stronglyConnComps
  :: DPGraph f v
  -> [SCC f v]
stronglyConnComps dpgraph = mapMaybe toComponent (scc adjacency)
 where
  adjacency = graph dpgraph
  ruleOf = fst3 . nodeFromVertex dpgraph

  -- Components with two or more vertices are cyclic by construction; a lone
  -- vertex needs an edge to itself.
  toComponent tree@(Node root children)
    | null children && root `notElem` (adjacency ! root) = Nothing
    | otherwise = Just (SCC (foldr ((:) . ruleOf) [] tree))

-- | Approximates the dependency graph of an LCTRS. The graph is built over the
-- dependency pairs computed by 'computeDPs', using the LCTRS's defined symbols.
approximateDPGraph
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => LCTRS (FId f) (VId v)
  -> StateM (DPGraph (FId f) (VId v))
approximateDPGraph lctrs =
  approximateDPGraphFromDPs pairs symbols
 where
  pairs = computeDPs lctrs
  symbols = definedSyms lctrs

-- | Builds the approximated dependency graph over the given dependency pairs.
--
-- Pairs are keyed consecutively from 1 in list order. Every ordered pair of
-- dependency pairs (including a pair with itself) is tested with 'edgesTo',
-- in list order.
approximateDPGraphFromDPs
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => [Rule (FId f) (VId v)]
  -> S.Set (FId f)
  -> StateM (DPGraph (FId f) (VId v))
approximateDPGraphFromDPs dps defSyms = do
  adjacency <- traverse withSuccessors numbered
  let (g, fromVertex, toVertex) = graphFromEdges adjacency
  pure (DPGraph g fromVertex toVertex (IM.fromList numbered))
 where
  numbered = zip [1 ..] dps
  -- One 'graphFromEdges' node: the pair, its key, and the keys it points to.
  withSuccessors (key, dp) = (dp, key,) <$> edgesTo defSyms dp numbered

-- | Returns, in input order, the keys of those indexed dependency pairs that
-- @dp1@ has an edge to according to 'hasEdgeFrom'.
edgesTo
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => S.Set (FId f)
  -> Rule (FId f) (VId v)
  -> [(key, Rule (FId f) (VId v))]
  -> StateM [key]
edgesTo definedSyms dp1 indexedDPs = do
  connected <- filterM (hasEdgeFrom definedSyms dp1 . snd) indexedDPs
  pure [key | (key, _) <- connected]

-- | Decides whether the approximated dependency graph gets an edge from
-- @dp1@ to @dp2@.
--
-- @dp2@ is first renamed with fresh variables. 'psi' then either rules the
-- connection out ('Nothing', meaning no edge) or returns equations between
-- logic terms. These equations are conjoined with both guards and sent to the
-- SMT solver. An inconclusive solver answer counts as an edge.
hasEdgeFrom
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => S.Set (FId f)
  -> Rule (FId f) (VId v)
  -> Rule (FId f) (VId v)
  -> StateM Bool
hasEdgeFrom definedSyms dp1 dp2 = do
  -- Rename the target pair apart so that the two pairs share no variables.
  target <- R.renameFresh dp2
  let logicalVars = lvar dp1 `S.union` lvar target
  case psi logicalVars definedSyms (rhs dp1) (lhs target) of
    Nothing -> pure False
    Just equations -> do
      let sourceGuard = collapseGuardToTerm (guard dp1)
          targetGuard = collapseGuardToTerm (guard target)
          equationVars = concatMap (\(l, r) -> T.vars l <> T.vars r) equations
          smtVars =
            S.fromList (T.vars sourceGuard <> T.vars targetGuard <> equationVars)
          conjuncts =
            toSExpr sourceGuard
              : toSExpr targetGuard
              : [S.eq (toSExpr l) (toSExpr r) | (l, r) <- equations]
      -- No definite answer from the solver: keep the edge.
      maybe True id <$> satSExpr smtVars (S.andMany conjuncts)

-- | Checks whether two terms have the same sort.
prepareForSMTCheck
  :: (Ord v, Ord f, ToSExpr f, ToSExpr v)
  => Term (FId f) (VId v)
  -> Term (FId f) (VId v)
  -> Bool
prepareForSMTCheck s = (sort s ==) . sort

-- | Syntactic pre-check for an edge from the right-hand side @s@ of one
-- dependency pair to the left-hand side @t@ of another (renamed-apart) pair.
--
-- Returns 'Nothing' when the terms are found to be incompatible; 'hasEdgeFrom'
-- then reports no edge. Otherwise it returns the equations between logic terms
-- that must hold for a connection, and 'hasEdgeFrom' checks their
-- satisfiability together with both guards. Sub-terms that may presumably
-- still rewrite, or be instantiated freely, contribute no equation.
-- NOTE(review): this is intended as an over-approximation of the real graph;
-- confirm against the reference definition of psi.
psi
  :: (Ord f, Ord v, ToSExpr f, ToSExpr v)
  => S.Set (VId v)
  -- ^ logical variables of both pairs (as collected via 'lvar')
  -> S.Set (FId f)
  -- ^ defined symbols
  -> Term (FId f) (VId v)
  -- ^ @s@: right-hand side of the source pair
  -> Term (FId f) (VId v)
  -- ^ @t@: left-hand side of the target pair
  -> Maybe [(Term (FId f) (VId v), Term (FId f) (VId v))]
psi lvars definedSyms = psi'
 where
  -- The equations are tried top to bottom. When a guard fails, matching falls
  -- through to the later equations.
  --
  -- A non-logical variable in @s@ imposes no constraint.
  psi' (Var s) _
    | s `S.notMember` lvars = return []
  -- @s@ is rooted in a defined symbol and is not a logic term over logical
  -- variables only. It may presumably still be rewritten, so no constraint.
  psi' s@(Fun _ f _) _
    | f `S.member` definedSyms
        && not
          ( isLogicTerm s
              && S.fromList (vars s) `S.isSubsetOf` lvars
          ) =
        return []
  -- A non-value theory term that is not a logic term over logical variables,
  -- facing a variable or a value in @t@: no constraint.
  psi' s@(TheoryFun _ _) t
    | not (isValue s)
        && not (isLogicTerm s && S.fromList (vars s) `S.isSubsetOf` lvars)
        && (isVar t || isValue t) =
        return []
  -- A non-defined TermFun application facing a non-logical variable of @t@:
  -- no constraint.
  psi' (TermFun f _) (Var v)
    | f `S.notMember` definedSyms
        && v `S.notMember` lvars =
        return []
  -- A non-defined value facing a non-logical variable of @t@: no constraint.
  psi' (Val f) (Var v)
    | f `S.notMember` definedSyms
        && v `S.notMember` lvars =
        return []
  -- Same non-defined root symbol on both sides: compare the arguments
  -- pairwise and collect every equation. Fails if any argument pair fails.
  -- (zipWithM stops at the shorter list; equal symbols presumably have equal
  -- arity.)
  psi' (Fun _ f ss) (Fun _ g ts)
    | f == g && f `S.notMember` definedSyms =
        concat <$> zipWithM psi' ss ts
  -- Two logic terms of the same sort, with all variables of @s@ logical:
  -- record the equation @s = t@ for the SMT check.
  psi' s t
    | isLogicTerm s
        && isLogicTerm t
        && S.fromList (T.vars s) `S.isSubsetOf` lvars
        && sort s == sort t =
        return [(s, t)]
  -- Any other combination: the terms cannot be connected.
  psi' _ _ = Nothing
