{-# LANGUAGE DeriveTraversable #-}

{- |
Module      : Data.Polynomial.Polynomial
Description : Sparse multivariate polynomials over an arbitrary coefficient type
Copyright   : (c) Jonas Schöpf, 2024
License     : GPL-3
Maintainer  : jonas.schoepf@uibk.ac.at
Stability   : stable


This module provides the main types and auxiliary functions for polynomials.
-}
module Data.Polynomial.Polynomial where

import Control.Monad.Identity (Identity (..))
import qualified Data.Map.Merge.Strict as M
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import qualified Data.MultiSet as MS
import Data.Polynomial.Monomial (Monomial (..), toPowers)
import qualified Data.Set as S

-- | A multivariate polynomial with variables of type @v@ and coefficients
-- of type @c@, stored sparsely as a map from each 'Monomial' to its
-- coefficient.
--
-- The derived 'Eq' and 'Ord' compare the underlying maps structurally, so
-- a term carrying an explicit zero coefficient makes two otherwise equal
-- polynomials compare unequal. '+', '*' and 'fromInteger' drop such terms
-- (via 'norm'); 'fromMonos' and 'fmap' do not.
newtype Polynomial v c = Polynomial (M.Map (Monomial v) c)
  deriving (Eq, Ord, Functor, Foldable, Traversable)

-- polynomials

-- | The polynomial with no terms, i.e. the /zero/ polynomial (the same
-- value as 'zeroPoly').
--
-- NOTE(review): despite its name this is not the multiplicative unit @1@
-- (that would be @coefficient 1@); confirm callers expect zero.
unit :: Polynomial v c
unit = Polynomial M.empty

-- | The constant polynomial with the given value. The result is
-- normalized, so @coefficient 0@ is the polynomial with no terms.
coefficient :: (Eq c, Num c) => c -> Polynomial v c
coefficient = norm . Polynomial . M.singleton (Monomial MS.empty)

-- | The polynomial consisting of the single variable @x@ with
-- coefficient 1.
variable :: (Num c) => v -> Polynomial v c
variable x = Polynomial (M.singleton mono 1)
 where
  mono = Monomial (MS.singleton x)

-- | Lift a monomial into a polynomial with coefficient 1.
fromMono :: (Num c) => Monomial v -> Polynomial v c
fromMono = Polynomial . flip M.singleton 1

-- | Build a polynomial from @(coefficient, monomial)@ pairs.
--
-- NOTE(review): pairs are inserted with 'M.fromList', so when the same
-- monomial occurs more than once only the /last/ coefficient is kept
-- (coefficients are not summed). Zero coefficients are also kept, because
-- the result is not normalized. Callers that may pass duplicates should
-- combine them first, or build the polynomial with '+'.
fromMonos :: (Ord v) => [(c, Monomial v)] -> Polynomial v c
fromMonos = Polynomial . M.fromList . map (\(coeff, mono) -> (mono, coeff))

-- | All terms as @(coefficient, monomial)@ pairs, in ascending monomial
-- order.
toMonos :: Polynomial v c -> [(c, Monomial v)]
toMonos (Polynomial terms) = M.foldrWithKey (\mono coeff rest -> (coeff, mono) : rest) [] terms

-- | The underlying map from monomials to coefficients.
toMonoMap :: Polynomial v c -> M.Map (Monomial v) c
toMonoMap poly = case poly of
  Polynomial terms -> terms

-- | The coefficient of the given monomial, or 0 if the polynomial has no
-- such term.
coefficientOf :: (Ord v, Num c) => Monomial v -> Polynomial v c -> c
coefficientOf m (Polynomial p) = M.findWithDefault 0 m p

-- | All stored coefficients, ordered by ascending monomial.
coefficients :: Polynomial v c -> [c]
coefficients (Polynomial terms) = map snd (M.toAscList terms)

-- | The distinct variables occurring in any monomial of the polynomial,
-- in ascending order. Terms with a zero coefficient still contribute
-- their variables.
variables :: (Ord v) => Polynomial v c -> [v]
variables (Polynomial terms) =
  S.toAscList (S.unions [MS.toSet xs | Monomial xs <- M.keys terms])

-- | Rename every variable with @f@. Monomials that become equal after
-- renaming have their coefficients summed, and the result is normalized.
rename :: (Eq c, Num c, Ord v') => (v -> v') -> Polynomial v c -> Polynomial v' c
rename f (Polynomial terms) = norm (Polynomial (M.mapKeysWith (+) renameMono terms))
 where
  renameMono (Monomial xs) = Monomial (MS.map f xs)

-- | Normalize a polynomial by removing every term whose coefficient is 0.
norm :: (Num c, Eq c) => Polynomial v c -> Polynomial v c
norm (Polynomial terms) = Polynomial (M.filter nonZero terms)
 where
  nonZero coeff = 0 /= coeff

-- | Test whether a polynomial is the zero polynomial, i.e. every stored
-- coefficient equals 0.
--
-- The test looks at coefficients rather than at map emptiness, so
-- polynomials carrying explicit zero coefficients (e.g. from 'fromMonos'
-- or 'fmap', which do not normalize) are also recognised as zero. A
-- constant polynomial with a non-zero constant term is /not/ zero.
isZero :: (Num c, Eq c) => Polynomial v c -> Bool
isZero (Polynomial terms) = all (== 0) terms

-- | The zero polynomial: a polynomial without any terms.
zeroPoly :: Polynomial v c
zeroPoly = Polynomial noTerms
 where
  noTerms = M.empty

-- | Ring operations on polynomials.
--
-- The results of '+', '*' and 'fromInteger' are normalized (terms with
-- coefficient 0 are dropped). 'negate' maps over the coefficients.
-- 'abs' and 'signum' have no meaningful definition here and raise an
-- error.
instance (Eq c, Ord v, Num c) => Num (Polynomial v c) where
  Polynomial lhs + Polynomial rhs =
    norm (Polynomial (M.unionWith (+) lhs rhs))

  Polynomial lhs * Polynomial rhs =
    norm (Polynomial (M.fromListWith (+) products))
   where
    -- Every pairwise product of terms. MS.union (multiset package) adds
    -- occurrence counts, which multiplies the two monomials. Products
    -- landing on the same monomial are summed by fromListWith.
    products = do
      (Monomial xs, a) <- M.toList lhs
      (Monomial ys, b) <- M.toList rhs
      pure (Monomial (xs `MS.union` ys), a * b)

  negate = fmap negate
  fromInteger n = coefficient (fromInteger n)
  abs = error "Polynomial.Num.abs: not defined."
  signum = error "Polynomial.Num.signum: not defined."

-- eval

-- | Substitute a polynomial for every variable and expand the result.
--
-- Each term @c * v1^p1 * ... * vk^pk@ becomes
-- @coefficient c * s v1 ^ p1 * ... * s vk ^ pk@, and all such
-- polynomials are summed. All arithmetic goes through the 'Num' instance,
-- so the result contains no zero coefficients.
substitute :: (Eq c, Num c, Ord v') => (v -> Polynomial v' c) -> Polynomial v c -> Polynomial v' c
substitute s = substPoly
 where
  substPoly p = sum [coefficient c * substMono m | (c, m) <- toMonos p]
  -- '^' needs only O(log p) polynomial multiplications, versus O(p) for
  -- multiplying out @replicate p (s v)@. It also matches how evalWithM
  -- and fromPolynomialM raise variables to powers.
  -- NOTE(review): '^' rejects negative exponents; this assumes 'toPowers'
  -- only yields non-negative multiplicities — confirm.
  substMono m = product [s v ^ p | (v, p) <- toPowers m]

-- | Effectfully interpret a polynomial in another numeric type @a@.
--
-- @var@ interprets each variable and @coeff@ each coefficient. Per term,
-- the coefficient's effect runs before the effects of its variables, and
-- the terms are processed in the order given by 'toMonos'.
fromPolynomialM :: (Num a, Applicative f) => (v -> f a) -> (c -> f a) -> Polynomial v c -> f a
fromPolynomialM var coeff = evalTerms . toMonos
 where
  evalTerms terms = sum <$> traverse evalTerm terms
  evalTerm (c, m) = (*) <$> coeff c <*> evalMono m
  evalMono m = product <$> traverse (\(v, p) -> (^) <$> var v <*> pure p) (toPowers m)

-- | Pure counterpart of 'fromPolynomialM': interpret variables with @var@
-- and coefficients with @coeff@.
fromPolynomial :: (Num a) => (v -> a) -> (c -> a) -> Polynomial v c -> a
fromPolynomial var coeff poly =
  runIdentity (fromPolynomialM (Identity . var) (Identity . coeff) poly)

-- | Evaluate a polynomial in its own coefficient type, looking up each
-- variable's value with an applicative action.
evalWithM :: (Num c, Applicative m) => (v -> m c) -> Polynomial v c -> m c
evalWithM getValue = fmap sum . traverse evalTerm . toMonos
 where
  evalTerm (c, m) = (c *) <$> evalMono m
  evalMono m = fmap product (traverse (\(v, p) -> (^ p) <$> getValue v) (toPowers m))

-- | Pure counterpart of 'evalWithM': evaluate with a plain variable
-- assignment.
evalWith :: (Num c) => (v -> c) -> Polynomial v c -> c
evalWith getValue poly = runIdentity (evalWithM (Identity . getValue) poly)

-- | Combine the coefficients of two polynomials monomial-wise.
--
-- Monomials only in the first polynomial are mapped with the first
-- function, monomials only in the second with the second function, and
-- monomials present in both are combined with the third. The result is
-- not normalized.
zipCoefficientsWith
  :: (Ord v)
  => (c1 -> c3)
  -> (c2 -> c3)
  -> (c1 -> c2 -> c3)
  -> Polynomial v c1
  -> Polynomial v c2
  -> Polynomial v c3
zipCoefficientsWith onlyLeft onlyRight both (Polynomial lhs) (Polynomial rhs) =
  Polynomial
    ( M.merge
        (M.mapMissing (const onlyLeft))
        (M.mapMissing (const onlyRight))
        (M.zipWithMatched (const both))
        lhs
        rhs
    )

-- | Pair up the coefficients of two polynomials over the union of their
-- monomials, in ascending monomial order. A monomial missing from one
-- side gets coefficient 0 on that side.
zipCoefficients :: (Ord v, Num c1, Num c2) => Polynomial v c1 -> Polynomial v c2 -> [(c1, c2)]
zipCoefficients p1 p2 = coefficients (zipCoefficientsWith leftOnly rightOnly (,) p1 p2)
 where
  leftOnly a = (a, 0)
  rightOnly b = (0, b)
