{- |
Module      : CriticalPairSplitting
Description : Splitting of constrained critical pairs
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides functions for splitting constrained critical pairs.
-}
module Rewriting.CriticalPairSplitting where

import Control.Monad (filterM)
import Data.Foldable (find)
import Data.LCTRS.FIdentifier (FId)
import Data.LCTRS.Guard (
  collapseGuardToTerm,
  conjGuards,
  createGuard,
  mapGuard,
 )
import Data.LCTRS.LCTRS (LCTRS, getRules)
import Data.LCTRS.Rule (rename, renameFresh)
import Data.LCTRS.Sort (Sorted)
import Data.LCTRS.Term (subterms)
import qualified Data.LCTRS.Term as T
import Data.LCTRS.VIdentifier (VId)
import Data.Maybe (isJust, isNothing, mapMaybe)
import Data.Monad (StateM)
import Data.SExpr (ToSExpr)
import Prettyprinter (Pretty)
import Rewriting.ConstrainedRewriting.ConstrainedRewriting (
  isReducibleInstance,
  matchesAnyPos,
 )
import Rewriting.CriticalPair (CriticalPair (..))
import Rewriting.Renaming (innerRenaming)
import Rewriting.Substitution (apply)

{- | Try to close a critical pair, splitting it if this is necessary.

The pair is first handed to @solve@ unchanged. If that yields a proof, the
result is the singleton list pairing the critical pair with this proof.
Otherwise the search starts from one open branch holding the pair and expands
all current branches with 'tryToCloseStep', round after round. As soon as an
expanded branch has no critical pair left over, the (critical pair, proof)
list collected on that branch is returned. 'Nothing' is returned once a round
produces no branches at all.
-}
tryToClose
  :: (Pretty v, Pretty f, Ord v, Ord f, ToSExpr v, ToSExpr f)
  => ( LCTRS (FId f) (VId v)
       -> CriticalPair (FId f) (VId v) (VId v)
       -> StateM
            ( Maybe proof
            )
     )
  -> LCTRS (FId f) (VId v)
  -> CriticalPair (FId f) (VId v) (VId v)
  -> StateM
      ( Maybe
          [ ( CriticalPair (FId f) (VId v) (VId v)
            , proof
            )
          ]
      )
tryToClose solve lc cp =
  solve lc cp >>= maybe (expand [([], Just cp)]) (\p -> return $ Just [(cp, p)])
 where
  -- No branch left to expand: the search failed.
  expand [] = return Nothing
  expand open = do
    next <- concat <$> mapM (tryToCloseStep solve lc) open
    -- The first branch without a remaining critical pair wins; otherwise
    -- continue with every branch produced in this round.
    maybe (expand next) (return . Just . fst) $ find (isNothing . snd) next

{- | Expand a single branch of the splitting search by one step.

A branch consists of the (critical pair, proof) list collected so far and,
optionally, a critical pair that is still open. A branch without an open pair
is returned unchanged as the only result.

For a branch with an open pair @cp@, the rules of the LCTRS are renamed
('renameFresh' followed by 'rename' with 'innerRenaming') and passed to
'matchesAnyPos' together with the subterms of both sides of @cp@ and its
constraint. The candidates accepted by 'isReducibleInstance' each give a guard
@c@ (the substitution applied to the plain guard), which splits @cp@ into

* @rule@: @cp@ with @c@ added to its constraint via 'conjGuards', and
* @notRule@: @cp@ with the negation of @c@ added to its constraint.

A candidate contributes a new branch only if @solve@ closes @rule@. In that
case @notRule@ is tried as well: if it is closed too, the branch is finished
(no open pair), otherwise @notRule@ becomes the open pair of the new branch.
@notRule@ is only handed to @solve@ after @rule@ was closed, because its
result would be discarded otherwise.
-}
tryToCloseStep
  :: (Pretty v, Pretty v', Pretty f, Ord v, Ord v', Ord f, ToSExpr f, Sorted v')
  => (LCTRS (FId f) (VId v) -> CriticalPair (FId f) (VId v) v' -> StateM (Maybe b))
  -> LCTRS (FId f) (VId v)
  -> ([(CriticalPair (FId f) (VId v) v', b)], Maybe (CriticalPair (FId f) (VId v) v'))
  -> StateM
      [([(CriticalPair (FId f) (VId v) v', b)], Maybe (CriticalPair (FId f) (VId v) v'))]
tryToCloseStep _ _ (solutions, Nothing) = return [(solutions, Nothing)]
tryToCloseStep solve lc (solutions, Just cp) = do
  rules <- mapM (fmap (rename innerRenaming) . renameFresh) (getRules lc)
  matchings <-
    mapM (matchesAnyPos (subterms (inner cp) ++ subterms (outer cp)) (constraint cp)) rules
  instances <- filterM (isReducibleInstance (constraint cp)) (mconcat matchings)
  outcomes <- mapM splitOn instances
  return $ mapMaybe (fmap extend) outcomes
 where
  -- Split the open pair along one candidate and try to close both halves.
  -- Yields 'Nothing' when the strengthened pair cannot be closed.
  splitOn (plain, subst, _) = do
    let c = mapGuard (subst `apply`) plain
        rule = cp{constraint = conjGuards c $ constraint cp}
        notRule = cp{constraint = conjGuards (createGuard [T.neg $ collapseGuardToTerm c]) $ constraint cp}
    ruleResult <- solve lc rule
    case ruleResult of
      Nothing -> return Nothing
      Just ruleProof -> do
        notRuleResult <- solve lc notRule
        return $ Just ((rule, ruleProof), (notRule, notRuleResult))

  -- Turn a successful split into the resulting branch.
  extend (solvedRule, (notRule, Just notRuleProof)) =
    (solvedRule : (notRule, notRuleProof) : solutions, Nothing)
  extend (solvedRule, (notRule, Nothing)) =
    (solvedRule : solutions, Just notRule)

{- | Try to close a critical pair, or at least one of its split parts.

The pair is first handed to @solve@ unchanged; a proof found this way is
returned as 'Left' together with the original pair. Otherwise the branches are
expanded with 'tryToCloseStep'. The first branch that has collected at least
one (critical pair, proof) entry determines the result: its most recently
added entry is returned as 'Right'. 'Nothing' is returned once no branches
remain.
-}
tryToCloseOne
  :: (Pretty v, Pretty f, Ord v, Ord f, ToSExpr v, ToSExpr f)
  => ( LCTRS (FId f) (VId v)
       -> CriticalPair (FId f) (VId v) (VId v)
       -> StateM
            ( Maybe proof
            )
     )
  -> LCTRS (FId f) (VId v)
  -> CriticalPair (FId f) (VId v) (VId v)
  -> StateM
      ( Maybe
          ( Either
              ( CriticalPair (FId f) (VId v) (VId v)
              , proof
              )
              ( CriticalPair (FId f) (VId v) (VId v)
              , proof
              )
          )
      )
tryToCloseOne solve lc cp = do
  mayVal <- solve lc cp
  case mayVal of
    Just p -> return $ Just $ Left (cp, p)
    Nothing -> go [([], Just cp)]
 where
  go [] = return Nothing
  go branches = do
    steps <- concat <$> mapM (tryToCloseStep solve lc) branches
    -- Matching on the list head directly keeps this total; the predicate
    -- already guarantees a non-empty list, so the fallback only covers the
    -- case that no branch has collected a solution yet.
    case find (not . null . fst) steps of
      Just (sol : _, _) -> return $ Just $ Right sol
      _ -> go $ filter (isJust . snd) steps

-- goStep (solutions, Nothing) = return [(solutions, Nothing)]
-- goStep (solutions, Just cp) = do
--   rules <- do
--     let oldRules = getRules lc
--     mapM (fmap (rename innerRenaming) . renameFresh) oldRules
--   matchings <-
--     mapM (matchesAnyPos (subterms (inner cp) ++ subterms (outer cp)) (constraint cp)) rules
--   instances <- case mconcat matchings of
--     [] -> return []
--     instances -> filterM (isReducibleInstance (constraint cp)) instances
--   solved <-
--     mapM
--       ( \(plain, subst, _) -> do
--           let c = mapGuard (subst `apply`) plain
--           let
--             rule = cp{constraint = conjGuards c $ constraint cp}
--             notRule = cp{constraint = conjGuards (createGuard [T.neg $ collapseGuardToTerm c]) $ constraint cp}
--           isSolved <- solve lc rule
--           sat <- satGuard (constraint notRule)
--           case sat of
--             Just False -> return ((rule, isSolved), Nothing)
--             _ -> return ((rule, isSolved), Just notRule)
--       )
--       instances
--   let results =
--         mapMaybe
--           (\((rule, mp), nR) -> case mp of Nothing -> Nothing; Just p -> Just ((rule, p) : solutions, nR))
--           solved
--   return results
