module SumSqs3;

// Unary (Peano-style) natural numbers: either Zero or the successor S(n) of another Nat.
data Nat = Zero | S(Nat);

// Addition on Nat, defined by recursion on the first argument:
// every S peeled off x is re-wrapped around the recursive result,
// and once x is exhausted the answer is y itself.
def Nat plus(Nat x, Nat y) =
    case x {
        S(pred) => S(plus(pred, y));
        Zero => y;
    };

// Multiplication on Nat as repeated addition, recursing on the first argument:
// mult(S(p), y) = plus(y, mult(p, y)) and mult(Zero, y) = Zero.
def Nat mult(Nat x, Nat y) =
    case x {
        S(pred) => plus(y, mult(pred, y));
        Zero => Zero;
    };

// Squares a Nat by multiplying it with itself.
def Nat square(Nat x) =
    mult(x, x);

// Builds a list by repeatedly applying the function parameter f, starting at seed z.
//   - f(z) == Nothing  : the list ends here (Nil).
//   - f(z) == Just(v)  : v is emitted as the next element AND used as the next seed.
// The recursive call passes only the new seed; the (f) parameter list is not repeated.
def List<Nat> unfoldr(f)(Nat z) =
    case f(z) {
        Just(next) => Cons(next, unfoldr(next));
        Nothing => Nil;
    };

// One step of a countdown: yields the predecessor of m wrapped in Just,
// or Nothing when m is Zero and there is no predecessor.
def Maybe<Nat> countdown(Nat m) =
    case m {
        S(pred) => Just(pred);
        Zero => Nothing;
    };

// Enumerates downwards from n.
//   - enum(Zero) is the empty list.
//   - For a non-zero n the result is n followed by everything unfoldr(countdown)
//     produces from n, i.e. the predecessors of n down to and including Zero.
def List<Nat> enum(Nat n) =
    case n {
        // The predecessor itself is not needed here, only the fact that n is non-zero.
        S(_) => Cons(n, unfoldr(countdown)(n));
        Zero => Nil;
    };


// Adds up all elements of a list of Nat using plus; the empty list sums to Zero.
def Nat sum(List<Nat> xs) =
    case xs {
        Cons(elem, others) => plus(elem, sum(others));
        Nil => Zero;
    };

// Sum of squares over enum(n): enumerate, square every element, then add them up.
// NOTE(review): map is not defined in this module; presumably the standard
// library's partially-defined map -- confirm.
def Nat sum_sqs(Nat n) =
    sum(
        map(square)(enum(n))
    );

// Entry-point function: simply delegates to sum_sqs with the same argument.
def Nat start(Nat n) =
    sum_sqs(n);


{
    // Main block contains no statements.

}