from models.cnn_fully_connected_loss_weights import DeepPhiModel

import tqdm
from pre_processing import *
import tensorflow as tf


def _pad_encoding(encoding, length):
    """Return *encoding* right-padded with ``0`` tokens up to *length*, as a numpy array.

    :param encoding: list of integer token ids (concatenated with ``+``, so it must be a ``list``).
    :param length: target length; an encoding that is already at least this long is not padded.
    """
    return np.array(encoding + (length - len(encoding)) * [0])


def _label_pairs(positives, negatives, max_word_len):
    """Build ``(label, a, b)`` triples: label 1 for *positives*, label 0 for *negatives*.

    Pairs in which either side is longer than *max_word_len* are dropped. Positives come
    first in the returned list; callers shuffle it afterwards.
    """
    def fits(a, b):
        return len(a) <= max_word_len and len(b) <= max_word_len

    return [(1, a, b) for a, b in positives if fits(a, b)] + \
           [(0, a, b) for a, b in negatives if fits(a, b)]


def _label_singles(positives, negatives, max_word_len):
    """Build ``(label, seq)`` pairs: label 1 for *positives*, label 0 for *negatives*.

    Sequences longer than *max_word_len* are dropped. Positives come first in the returned
    list; callers shuffle it afterwards.
    """
    return [(1, s) for s in positives if len(s) <= max_word_len] + \
           [(0, s) for s in negatives if len(s) <= max_word_len]


def evaluation(evaluation_data, max_examples_per_prop=20, embedding_dim=128, max_len=128, num_tokens=80,
               learning_rate=0.001, decay_rate=0.9, decay_steps=200, vocab_size=1024,
               log_dir='models/logdir', verbose=False, encoding_projection=False, projection_dim=256):
    """Evaluate the latest checkpoint in *log_dir* on every property task and print the accuracies.

    Reads one sub-directory per property from *evaluation_data* (sub-formulas, modus ponens,
    term/formula, unifiability, well-formedness, alpha equivalence), shuffles each data set,
    keeps at most *max_examples_per_prop* examples per property, restores the most recent
    checkpoint into a freshly built ``DeepPhiModel`` and runs a single evaluation pass.

    :param evaluation_data: directory containing the per-property sub-directories named by
        ``DataDirNames``.
    :param max_examples_per_prop: upper bound on the number of examples evaluated per property.
    :param embedding_dim, max_len, num_tokens, learning_rate, decay_rate, decay_steps, vocab_size,
        encoding_projection, projection_dim: forwarded to ``DeepPhiModel`` / ``TrainingFormulas``.
        NOTE(review): presumably these must match the values used for training so the restored
        variables line up with the built graph.
    :param log_dir: directory holding ``training_formulas.pkl`` and the TensorFlow checkpoints.
    :param verbose: if true, also print every individual prediction next to its label.
    :raises ValueError: if *log_dir* contains no TensorFlow checkpoint.
    """
    # Fail fast: without a checkpoint, all of the pre-processing below would be wasted and
    # ``saver.restore`` would only fail later with a much less helpful message.
    checkpoint_path = tf.train.latest_checkpoint(log_dir)
    if checkpoint_path is None:
        raise ValueError("No checkpoint found in log_dir '{}'; train a model first.".format(log_dir))

    sub_formula_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.SUBFORMULAS.value))
    mp_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.MODUS_PONENS.value))
    term_formula_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.TERM_FORMULA.value))
    unifiability_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.UNIFIABILITY.value))
    well_formed_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.WELL_FORMEDNESS.value))
    alpha_equiv_files = list_files_from_directory(os.path.join(evaluation_data, DataDirNames.ALPHA_EQUIVALENCE.value))

    print("\n\n############### Starting Evaluation... ###############")

    data_obj = TrainingFormulas.load_training_formulas(file=os.path.join(log_dir, 'training_formulas.pkl'),
                                                       vocab_size=vocab_size, max_len=max_len)
    word_len = data_obj.max_word_len

    # ### Pre-processing ###

    # sub term property and labels
    sub_form_training_list: List[Tuple[List[int], List[int]]] = \
        sub_term_pre_eval_processing(data_obj, sub_formula_files)
    shuffle(sub_form_training_list)

    # binary sub-formula: balance positives and negatives by truncating the larger side
    sf_bin_pos, sf_bin_neg = sub_formula_binary_pre_processing_with_encoding(data_obj, sub_formula_files)
    shuffle(sf_bin_pos)
    shuffle(sf_bin_neg)
    num_sf_bin = min(len(sf_bin_pos), len(sf_bin_neg))
    sf_bin_pos, sf_bin_neg = sf_bin_pos[:num_sf_bin], sf_bin_neg[:num_sf_bin]
    sub_formula_binary_all = _label_pairs(sf_bin_pos, sf_bin_neg, word_len)
    shuffle(sub_formula_binary_all)

    # modus ponens: pos and neg can be of different lengths;
    # these are lists of pairs where the pairs in pos are positive classification
    modus_ponens_pos_training, modus_ponens_neg_training = modus_ponens_pre_processing(data_obj, mp_files)
    mp_all_data = _label_pairs(modus_ponens_pos_training, modus_ponens_neg_training, word_len)
    shuffle(mp_all_data)

    # term formula classification data (terms are labelled 1, formulas 0)
    te_fo_forms, te_fo_terms = formula_term_class_prep(data_obj, term_formula_files)
    all_term_forms = _label_singles(te_fo_terms, te_fo_forms, word_len)
    shuffle(all_term_forms)

    # unifiability
    unifiable, not_unifiable = unifiability_classifier_data_generation(data_obj, unifiability_files)
    unifiability_all_data = _label_pairs(unifiable, not_unifiable, word_len)
    shuffle(unifiability_all_data)

    # well formedness
    well_formed, not_well_formed = well_formedness_pre_processing_with_encoding(data_obj, well_formed_files)
    w_formedness_all_data = _label_singles(well_formed, not_well_formed, word_len)
    shuffle(w_formedness_all_data)

    # Alpha equivalence
    alpha_equiv, alpha_not_equiv = alpha_equivalence_pre_processing_with_encoding(data_obj, alpha_equiv_files)
    alpha_all_data: List[Tuple[int, List[int], List[int]]] = _label_pairs(alpha_equiv, alpha_not_equiv, word_len)
    shuffle(alpha_all_data)

    print("\n####### Evaluation Data Information ######")
    print("sub_formula_training_list size: {}".format(len(sub_form_training_list)))

    print("sub_formula_binary_all size: {}".format(len(sub_formula_binary_all)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(sf_bin_pos), len(sf_bin_neg)))

    print("mp_all_data size: {}".format(len(mp_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(modus_ponens_pos_training),
                                                       len(modus_ponens_neg_training)))

    print("all_term_forms size: {}".format(len(all_term_forms)))
    print("{:>7}:{:<7}\t approximate ratio forms:terms".format(len(te_fo_forms), len(te_fo_terms)))

    print("unifiability_all_data size: {}".format(len(unifiability_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(unifiable), len(not_unifiable)))

    print("w_formedness_all_data size: {}".format(len(w_formedness_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(well_formed), len(not_well_formed)))

    print("alpha_all_data size: {}".format(len(alpha_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(alpha_equiv), len(alpha_not_equiv)))

    print("\n Randomly choosing at most {} examples for evaluation.".format(max_examples_per_prop))
    batch_subterm = sub_form_training_list[:max_examples_per_prop]
    batch_mp = mp_all_data[:max_examples_per_prop]
    batch_term_forms = all_term_forms[:max_examples_per_prop]
    batch_unifiability = unifiability_all_data[:max_examples_per_prop]
    batch_subform_binary = sub_formula_binary_all[:max_examples_per_prop]
    batch_well_formed = w_formedness_all_data[:max_examples_per_prop]
    batch_alpha_equiv: List[Tuple[int, List[int], List[int]]] = alpha_all_data[:max_examples_per_prop]

    print("Loading graph...")

    g = tf.Graph()
    with g.as_default():
        # encoder
        small_enc = DeepPhiModel(embedding_dim, DeephiMode.DEFAULT,
                                 sub_fm_multi_label_classes=vocab_size,
                                 max_len=max_len,
                                 uniform_init_scale=0.1,
                                 learning_rate=learning_rate,
                                 decay_steps=decay_steps,
                                 decay_rate=decay_rate,
                                 num_tokens=num_tokens,
                                 mp_classifier=True,
                                 term_formula_classifier=True,
                                 unification_classifier=True,
                                 binary_subform_classifier=True,
                                 well_formed_classifier=True,
                                 alpha_equi_classifier=True,
                                 encoding_projection=encoding_projection,
                                 encoding_projection_dim=projection_dim)

        small_enc.build_model()
        saver = tf.train.Saver()

    # The ``with`` block closes the session on exit; no explicit close() is needed.
    with tf.Session(graph=g) as session:
        # Restore the most recent checkpoint (one could also average the last few checkpoint weights).
        saver.restore(session, checkpoint_path)
        print("Restored graph...")

        # ### DATA Preparation ###

        # subterm
        x = np.array([np.array(b[0]) for b in batch_subterm])
        y = np.array([b[1] for b in batch_subterm])
        lengths = np.array([len(xi) for xi in x])

        # binary subformula
        # NOTE(review): these lengths are taken from the *padded* rows (so always word_len),
        # unlike term/formula, unifiability, well-formedness and alpha equivalence, which use
        # the unpadded length. ``train`` computes them the same way, so this is kept for
        # train/eval consistency -- confirm it is intended.
        bin_sub_fm_x = np.array([_pad_encoding(b[1], word_len) for b in batch_subform_binary])
        lengths_bin_sub_fm_x = np.array([len(xi) for xi in bin_sub_fm_x])

        bin_sub_fm_y = np.array([_pad_encoding(b[2], word_len) for b in batch_subform_binary])
        lengths_bin_sub_fm_y = np.array([len(xi) for xi in bin_sub_fm_y])

        bin_sub_fm_labels = np.array([l[0] for l in batch_subform_binary])

        # mp (lengths are padded lengths, see the note above)
        mp_a: np.ndarray = np.array([_pad_encoding(a[1], word_len) for a in batch_mp])
        lengths_a = np.array([len(a) for a in mp_a])
        mp_b = np.array([_pad_encoding(b[2], word_len) for b in batch_mp])
        lengths_b = np.array([len(a) for a in mp_b])
        mp_labels = np.array([l[0] for l in batch_mp])

        # forms and terms
        tfs: np.ndarray = np.array([_pad_encoding(a[1], word_len) for a in batch_term_forms])
        tf_lengths = np.array([len(a[1]) for a in batch_term_forms])
        tf_labels = np.array([a[0] for a in batch_term_forms])

        # unifiability
        unifiability_a: np.ndarray = np.array([_pad_encoding(a[1], word_len) for a in batch_unifiability])
        lengths_unifiability_a = np.array([len(a[1]) for a in batch_unifiability])
        unifiability_b: np.ndarray = np.array([_pad_encoding(a[2], word_len) for a in batch_unifiability])
        lengths_unifiability_b = np.array([len(a[2]) for a in batch_unifiability])
        unifiability_labels = np.array([l[0] for l in batch_unifiability])

        # Well formedness
        wf_input: np.ndarray = np.array([_pad_encoding(a[1], word_len) for a in batch_well_formed])
        wf_labels: np.ndarray = np.array([a[0] for a in batch_well_formed])
        wf_length: np.ndarray = np.array([len(a[1]) for a in batch_well_formed])

        # Alpha Equivalence
        alpha_a: np.ndarray = np.array([_pad_encoding(a[1], word_len) for a in batch_alpha_equiv])
        alpha_a_length = np.array([len(a[1]) for a in batch_alpha_equiv])
        alpha_b: np.ndarray = np.array([_pad_encoding(a[2], word_len) for a in batch_alpha_equiv])
        alpha_b_length = np.array([len(a[2]) for a in batch_alpha_equiv])
        alpha_equiv_labels = np.array([l[0] for l in batch_alpha_equiv])

        feed_dict = {
            small_enc.general_encoder_input: x,
            small_enc.labels: y,
            small_enc.encode_length: lengths,
            small_enc.binary_subform_input_a: bin_sub_fm_x,
            small_enc.binary_subform_input_b: bin_sub_fm_y,
            small_enc.binary_subform_length_a: lengths_bin_sub_fm_x,
            small_enc.binary_subform_length_b: lengths_bin_sub_fm_y,
            small_enc.binary_subform_labels: bin_sub_fm_labels,
            small_enc.mp_input_form_a: mp_a,
            small_enc.mp_input_form_b: mp_b,
            small_enc.mp_second_input_length_a: lengths_a,
            small_enc.mp_second_input_length_b: lengths_b,
            small_enc.mp_class_labels: mp_labels,
            small_enc.term_formula_input: tfs,
            small_enc.term_formula_length: tf_lengths,
            small_enc.term_formula_labels: tf_labels,
            small_enc.unification_input_a: unifiability_a,
            small_enc.unification_input_b: unifiability_b,
            small_enc.unification_input_a_length: lengths_unifiability_a,
            small_enc.unification_input_b_length: lengths_unifiability_b,
            small_enc.unification_labels: unifiability_labels,
            small_enc.hopefully_well_formed_input: wf_input,
            small_enc.hopefully_well_formed_lengths: wf_length,
            small_enc.hopefully_well_formed_labels: wf_labels,
            small_enc.alpha_equi_input_a: alpha_a,
            small_enc.alpha_equi_input_b: alpha_b,
            small_enc.alpha_equi_length_a: alpha_a_length,
            small_enc.alpha_equi_length_b: alpha_b_length,
            small_enc.alpha_equi_labels: alpha_equiv_labels
        }

        # ``binary_predictions`` / ``binary_accuracy`` are the modus-ponens outputs
        # (``train`` labels ``binary_accuracy`` as mp as well).
        highest_ind, sub_acc, preds, mp_acc, tf_preds, tf_acc, unif_preds, unif_acc, bin_sub_preds, bin_sub_acc, \
            well_f_preds, well_f_acc, alpha_equiv_preds, alpha_equiv_acc = session.run(
                [
                    small_enc.highest_indices, small_enc.accuracy,
                    small_enc.binary_predictions, small_enc.binary_accuracy,
                    small_enc.term_formula_predictions, small_enc.term_formula_accuracy,
                    small_enc.uni_predictions, small_enc.uni_accuracy,
                    small_enc.bin_sub_predictions, small_enc.bin_sub_accuracy,
                    small_enc.well_formed_predictions, small_enc.well_formed_accuracy,
                    small_enc.alpha_equi_predictions, small_enc.alpha_equi_accuracy
                ], feed_dict=feed_dict)

        if verbose:
            print("######## SUBFORMULA \t acc:{: <20}\n".format(sub_acc))
            for p, l in zip(x, highest_ind):
                print("\nFormula: {}".format(data_obj.string_of_encoding(p)))
                for i in l:
                    print("\t{}".format(data_obj.word_of_key(i)))

            print("\n\n######## MODUS PONENS EVALUATION \t acc:{: <20}\n".format(mp_acc))
            for i in range(len(preds)):
                print("Term 1:\t{}\n".format(''.join(data_obj.string_of_encoding(mp_a[i]))))
                print("Term 2:\t{}\n".format(''.join(data_obj.string_of_encoding(mp_b[i]))))
                print("Prediction:\t{}\t{}".format(preds[i], np.argmax(preds[i])))
                print("Actual: {}".format(mp_labels[i]))
                print("\n")

            print("######## TERM or FORMULA \t acc: {: <20}\n".format(tf_acc))
            print("predictions:\n")
            for i in range(len(tf_preds)):
                print("Formula vs Term:\t {}".format(''.join(data_obj.string_of_encoding(tfs[i]))))
                print("prediction:\t {} \t {}".format(tf_preds[i], np.argmax(tf_preds[i])))
                print("Actual: {}".format(tf_labels[i]))
                print("\n")

            print("######## UNIF \t acc: {: <20}\n".format(unif_acc))
            print("predictions:\n")
            for i in range(len(unif_preds)):
                print("Term 1:\t {}".format(''.join(data_obj.string_of_encoding(unifiability_a[i]))))
                print("Term 2:\t {}".format(''.join(data_obj.string_of_encoding(unifiability_b[i]))))
                print("prediction:\t {} \t {}".format(unif_preds[i], np.argmax(unif_preds[i])))
                print("Actual: {}".format(unifiability_labels[i]))
                print("\n")

            print("######## Binary Sub-Formula \t acc: {: <20}\n".format(bin_sub_acc))
            print("predictions:\n")
            for i in range(len(bin_sub_preds)):
                print("Formula 1:\t {}".format(''.join(data_obj.string_of_encoding(bin_sub_fm_x[i]))))
                print("Formula 2:\t {}".format(''.join(data_obj.string_of_encoding(bin_sub_fm_y[i]))))
                print("prediction:\t {} \t {}".format(bin_sub_preds[i], np.argmax(bin_sub_preds[i])))
                print("Actual: {}".format(bin_sub_fm_labels[i]))
                print("\n")

            print("######## Well-Formedness \t acc: {: <20}\n".format(well_f_acc))
            print("predictions:\n")
            for i in range(len(well_f_preds)):
                print("Formula :\t {}".format(''.join(data_obj.string_of_encoding(wf_input[i]))))
                print("prediction:\t {} \t {}".format(well_f_preds[i], np.argmax(well_f_preds[i])))
                print("Actual: {}".format(wf_labels[i]))
                print("\n")

            print("########  Alpha Equivalent \t acc: {: <20}\n".format(alpha_equiv_acc))
            print("predictions:\n")
            for i in range(len(alpha_equiv_preds)):
                print("Formula 1:\t {}".format(''.join(data_obj.string_of_encoding(alpha_a[i]))))
                print("Formula 2:\t {}".format(''.join(data_obj.string_of_encoding(alpha_b[i]))))
                print("prediction:\t {} \t {}".format(alpha_equiv_preds[i], np.argmax(alpha_equiv_preds[i])))
                print("Actual: {}".format(alpha_equiv_labels[i]))
                print("\n")

        print("#### EVALUATION Results Overview ####")
        print("Sub-Formula multilabel accuracy:\t{: <20}".format(sub_acc))
        print("Binary Sub-Formula accuracy:\t{: <20}".format(bin_sub_acc))
        print("Modus Ponens accuracy:\t{: <20}".format(mp_acc))
        print("Term vs Formula classification accuracy:\t{: <20}".format(tf_acc))
        print("Unifiability accuracy:\t{: <20}".format(unif_acc))
        print("Well-Formedness accuracy:\t{: <20}".format(well_f_acc))
        print("Alpha Equivalent accuracy:\t{: <20}".format(alpha_equiv_acc))
        print("\n")


def train(training_folder, num_epoch: int, batch_size: int = 16, embedding_dim=128, max_len=128, num_tokens=80,
          learning_rate=0.001, decay_rate=0.9, decay_steps=200, vocab_size=1024, log_dir='models/logdir',
          verbose=False, encoding_projection=False, projection_dim=256):
    # Create Training Formulas object
    # Don't forget to delete pkl file if changed

    sub_formula_files = list_files_from_directory(os.path.join(training_folder, DataDirNames.SUBFORMULAS.value))
    mp_files = list_files_from_directory(os.path.join(training_folder, DataDirNames.MODUS_PONENS.value))
    all_term_forms = list_files_from_directory(os.path.join(training_folder, DataDirNames.TERM_FORMULA.value))
    unifiability_files = list_files_from_directory(os.path.join(training_folder, DataDirNames.UNIFIABILITY.value))
    well_formed_files = list_files_from_directory(os.path.join(training_folder, DataDirNames.WELL_FORMEDNESS.value))
    alpha_equiv_files = list_files_from_directory(os.path.join(training_folder, DataDirNames.ALPHA_EQUIVALENCE.value))

    data_obj = TrainingFormulas.load_training_formulas(file=os.path.join(log_dir, 'training_formulas.pkl'),
                                                       vocab_size=vocab_size, max_len=max_len)

    # Pre-processing

    print("Reading data...")
    lst = [
        (sub_formula_pre_processing, data_obj, sub_formula_files),
        (modus_ponens_pre_processing, data_obj, mp_files),
        (formula_term_class_prep, data_obj, all_term_forms),
        (unifiability_classifier_data_generation, data_obj, unifiability_files),
        (sub_formula_binary_pre_processing_with_encoding, data_obj, sub_formula_files),
        (well_formedness_pre_processing_with_encoding, data_obj, well_formed_files),
        (alpha_equivalence_pre_processing_with_encoding, data_obj, alpha_equiv_files)
    ]
    res: List = list(map(lambda x: x[0](x[1], x[2]), lst))

    # sub term property and labels
    sub_formula_training_list: List[Tuple[List[int], List[int]]] = res[0]
    shuffle(sub_formula_training_list)
    # read binary sub formula property and shuffle and make it evenly split
    sf_bin_pos, sf_bin_neg = res[4]
    shuffle(sf_bin_pos)
    shuffle(sf_bin_neg)
    num_sf_bin = min(len(sf_bin_pos), len(sf_bin_neg))
    sf_bin_pos = sf_bin_pos[:num_sf_bin]
    sf_bin_neg = sf_bin_neg[:num_sf_bin]
    sub_formula_binary_all = [(1, x, y) for x, y in sf_bin_pos
                              if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len] + \
                             [(0, x, y) for x, y in sf_bin_neg
                              if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len]
    shuffle(sub_formula_binary_all)

    # pos and neg can be of different lengths,
    # these are lists of pairs where the pairs in pos are positive classification
    modus_ponens_pos_training, modus_ponens_neg_training = res[1]
    mp_all_data = [(1, x, y) for x, y in modus_ponens_pos_training
                   if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len] + \
                  [(0, x, y) for x, y in modus_ponens_neg_training
                   if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len]
    shuffle(mp_all_data)

    # term formula classification data
    te_fo_forms, te_fo_terms = res[2]
    all_term_forms = [(1, t) for t in te_fo_terms if len(t) <= data_obj.max_word_len] + \
                     [(0, t) for t in te_fo_forms if len(t) <= data_obj.max_word_len]
    shuffle(all_term_forms)

    # unifiability
    unifiable, not_unifiable = res[3]
    unifiability_all_data = [(1, x, y) for x, y in unifiable
                             if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len] + \
                            [(0, x, y) for x, y in not_unifiable
                             if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len]
    shuffle(unifiability_all_data)

    # well formedness
    well_formed, not_well_formed = res[5]
    w_formedness_all_data = [(1, x) for x in well_formed if len(x) <= data_obj.max_word_len] + \
                            [(0, x) for x in not_well_formed if len(x) <= data_obj.max_word_len]
    shuffle(w_formedness_all_data)

    # Alpha equivalence
    alpha_equiv, alpha_not_equiv = res[6]
    alpha_all_data: List[Tuple[int, List[int], List[int]]] = \
        [(1, x, y) for x, y in alpha_equiv if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len] +\
        [(0, x, y) for x, y in alpha_not_equiv if len(x) <= data_obj.max_word_len and len(y) <= data_obj.max_word_len]
    shuffle(alpha_all_data)

    print("\nWord to int size: {}\n".format(len(data_obj.word_to_int)))

    # reducing data:
    sub_formula_binary_all = sub_formula_binary_all[:len(sub_formula_training_list)*6]
    # unifiability_all_data = unifiability_all_data[:len(unifiability_all_data)//10]
    # sub_formula_training_list = sub_formula_training_list[:len(sub_formula_training_list)//10]
    # all_term_forms = all_term_forms[:len(all_term_forms)//10]
    # unifiability_all_data = unifiability_all_data[:len(unifiability_all_data)//10]
    # mp_all_data = mp_all_data[:len(mp_all_data)//10]
    # w_formedness_all_data = w_formedness_all_data[:len(w_formedness_all_data)//10]

    min_data = min(len(sub_formula_training_list),
                   len(sub_formula_binary_all),
                   len(mp_all_data),
                   len(all_term_forms),
                   len(unifiability_all_data),
                   len(w_formedness_all_data),
                   len(alpha_all_data))
    num_batches = min_data // batch_size

    print("\n\n####### Training Data Information ######")
    print("sub_formula_training_list size: {}".format(len(sub_formula_training_list)))

    print("sub_formula_binary_all size: {}".format(len(sub_formula_binary_all)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(sf_bin_pos), len(sf_bin_neg)))

    print("mp_all_data size: {}".format(len(mp_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(modus_ponens_pos_training),
                                                       len(modus_ponens_neg_training)))

    print("all_term_forms size: {}".format(len(all_term_forms)))
    print("{:>7}:{:<7}\t approximate ratio forms:terms".format(len(te_fo_forms), len(te_fo_terms)))

    print("unifiability_all_data size: {}".format(len(unifiability_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(unifiable), len(not_unifiable)))

    print("w_formedness_all_data size: {}".format(len(w_formedness_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(well_formed), len(not_well_formed)))

    print("alpha_all_data size: {}".format(len(alpha_all_data)))
    print("{:>7}:{:<7}\t approximate ratio +:-".format(len(alpha_equiv), len(alpha_not_equiv)))

    print("\n\n")

    # chunk data
    batches_subform = list(split_into(sub_formula_training_list, num_batches))
    batches_subform_binary = list(split_into(sub_formula_binary_all, num_batches))
    batches_mp = list(split_into(mp_all_data, num_batches))
    batches_terms_forms = list(split_into(all_term_forms, num_batches))
    batches_unifiability = list(split_into(unifiability_all_data, num_batches))
    batches_well_formed = list(split_into(w_formedness_all_data, num_batches))
    batches_alpha_equivalence: List[List[Tuple[int, List[int], List[int]]]] = \
        list(split_into(alpha_all_data, num_batches))

    g = tf.Graph()
    with g.as_default():
        # encoder
        small_enc = DeepPhiModel(embedding_dim, DeephiMode.DEFAULT,
                                 sub_fm_multi_label_classes=vocab_size,
                                 max_len=max_len,
                                 uniform_init_scale=0.1,
                                 learning_rate=learning_rate,
                                 decay_steps=decay_steps,
                                 decay_rate=decay_rate,
                                 num_tokens=num_tokens,
                                 mp_classifier=True,
                                 term_formula_classifier=True,
                                 unification_classifier=True,
                                 binary_subform_classifier=True,
                                 well_formed_classifier=True,
                                 alpha_equi_classifier=True,
                                 encoding_projection=encoding_projection,
                                 encoding_projection_dim=projection_dim)

        small_enc.build_model()
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
    print("Starting training...")
    with tf.Session(graph=g) as session:
        writer = tf.summary.FileWriter(log_dir, session.graph)
        init.run()

        sum_epoch_loss: float = 0.0
        avg_epoch_mp_acc: float = 0.0
        avg_epoch_ft_acc: float = 0.0
        avg_epoch_uni_acc: float = 0.0
        avg_epoch_wellf_acc: float = 0.0
        avg_epoch_alpha_acc: float = 0.0
        avg_epoch_binary_sub: float = 0.0

        # all_x_subterms = remove_duplicates_in_list_list([x for x, _ in sub_formula_training_list])
        # thought_vecs = []
        for e in tqdm.tqdm(range(num_epoch), unit='epoch'):
            # run session
            # Other stuff
            print("Training Epoch {:>5}/{:<5}".format(e+1, num_epoch))
            for i in tqdm.tqdm(range(0, num_batches), unit='batch'):
                batch_subform = batches_subform[i]
                batch_mp = batches_mp[i]
                batch_term_forms = batches_terms_forms[i]
                batch_unifiability = batches_unifiability[i]
                batch_subform_binary = batches_subform_binary[i]
                batch_well_formed = batches_well_formed[i]
                batch_alpha_equiv: List[Tuple[int, List[int], List[int]]] = batches_alpha_equivalence[i]
                # ### DATA Preparation ###

                # sub formula
                x = np.array([b[0] for b in batch_subform])
                y = np.array([b[1] for b in batch_subform])
                lengths = np.array([len(xi) for xi in x])

                # binary subformula
                sub_fm_x = np.array([np.array(b[1] + ((data_obj.max_word_len - len(b[1])) * [0]))
                                     for b in batch_subform_binary])
                lengths_sub_fm_x = np.array([len(xi) for xi in sub_fm_x])

                sub_fm_y = np.array([np.array(b[2] + ((data_obj.max_word_len - len(b[2])) * [0]))
                                     for b in batch_subform_binary])
                lengths_sub_fm_y = np.array([len(xi) for xi in sub_fm_y])

                sub_fm_labels = np.array([l[0] for l in batch_subform_binary])

                # ####### mp
                mp_a: np.ndarray = np.array([np.array(a[1] + ((data_obj.max_word_len - len(a[1])) * [0]))
                                             for a in batch_mp])
                lengths_a = np.array([len(a) for a in mp_a])
                mp_b = np.array([np.array(b[2] + ((data_obj.max_word_len - len(b[2])) * [0])) for b in batch_mp])
                lengths_b = np.array([len(a) for a in mp_b])
                mp_labels = np.array([l[0] for l in batch_mp])

                # forms and terms
                tfs: np.ndarray = np.array([np.array(a[1] + ((data_obj.max_word_len - len(a[1])) * [0]))
                                            for a in batch_term_forms])
                tf_lengths = np.array([len(a[1]) for a in batch_term_forms])
                tf_labels = np.array([a[0] for a in batch_term_forms])

                # unifiability
                unifiability_a: np.ndarray = np.array([np.array(a[1] + ((data_obj.max_word_len - len(a[1])) * [0]))
                                                       for a in batch_unifiability])
                lengths_unifiability_a = np.array([len(a[1]) for a in batch_unifiability])
                unifiability_b: np.ndarray = np.array([np.array(a[2] + ((data_obj.max_word_len - len(a[2])) * [0]))
                                                       for a in batch_unifiability])
                lengths_unifiability_b = np.array([len(a[2]) for a in batch_unifiability])
                unifiability_labels = np.array([l[0] for l in batch_unifiability])

                # Well formedness
                wf_input: np.ndarray = np.array([np.array(a[1] + ((data_obj.max_word_len - len(a[1])) * [0]))
                                                 for a in batch_well_formed])
                wf_labels: np.ndarray = np.array([a[0] for a in batch_well_formed])
                wf_length: np.ndarray = np.array([len(a[1]) for a in batch_well_formed])

                # Alpha Equivalence
                alpha_a: np.ndarray = np.array([np.array(a[1] + ((data_obj.max_word_len - len(a[1])) * [0]))
                                                for a in batch_alpha_equiv])
                alpha_a_length = np.array([len(a[1]) for a in batch_alpha_equiv])
                alpha_b: np.ndarray = np.array([np.array(a[2] + ((data_obj.max_word_len - len(a[2])) * [0]))
                                                for a in batch_alpha_equiv])
                alpha_b_length = np.array([len(a[2]) for a in batch_alpha_equiv])
                alpha_equiv_labels = np.array([l[0] for l in batch_alpha_equiv])

                # #### FEED DICT and TRAINING ####

                # feeddict for already installed props
                feed_dict = {
                    small_enc.general_encoder_input: x,
                    small_enc.labels: y,
                    small_enc.encode_length: lengths,
                    small_enc.binary_subform_input_a: sub_fm_x,
                    small_enc.binary_subform_input_b: sub_fm_y,
                    small_enc.binary_subform_length_a: lengths_sub_fm_x,
                    small_enc.binary_subform_length_b: lengths_sub_fm_y,
                    small_enc.binary_subform_labels: sub_fm_labels,
                    small_enc.mp_input_form_a: mp_a,
                    small_enc.mp_input_form_b: mp_b,
                    small_enc.mp_second_input_length_a: lengths_a,
                    small_enc.mp_second_input_length_b: lengths_b,
                    small_enc.mp_class_labels: mp_labels,
                    small_enc.term_formula_input: tfs,
                    small_enc.term_formula_length: tf_lengths,
                    small_enc.term_formula_labels: tf_labels,
                    small_enc.unification_input_a: unifiability_a,
                    small_enc.unification_input_b: unifiability_b,
                    small_enc.unification_input_a_length: lengths_unifiability_a,
                    small_enc.unification_input_b_length: lengths_unifiability_b,
                    small_enc.unification_labels: unifiability_labels,
                    small_enc.hopefully_well_formed_input: wf_input,
                    small_enc.hopefully_well_formed_lengths: wf_length,
                    small_enc.hopefully_well_formed_labels: wf_labels,
                    small_enc.alpha_equi_input_a: alpha_a,
                    small_enc.alpha_equi_input_b: alpha_b,
                    small_enc.alpha_equi_length_a: alpha_a_length,
                    small_enc.alpha_equi_length_b: alpha_b_length,
                    small_enc.alpha_equi_labels: alpha_equiv_labels
                }

                run_metadata = tf.RunMetadata()

                _, loss_val, summary, acc, all_acc, mp_acc, ft_acc, uni_acc, bin_sub_acc, well_f_acc, alpha_equi_acc =\
                    session.run(
                        [
                            small_enc.optimizer,
                            small_enc.total_loss,
                            small_enc.merged,
                            small_enc.accuracy,
                            small_enc.all_corr_accuracy,
                            small_enc.binary_accuracy,  # mp
                            small_enc.term_formula_accuracy,
                            small_enc.uni_accuracy,
                            small_enc.bin_sub_accuracy,
                            small_enc.well_formed_accuracy,
                            small_enc.alpha_equi_accuracy
                        ], feed_dict=feed_dict, run_metadata=run_metadata)

                sum_epoch_loss += loss_val

                avg_epoch_mp_acc += mp_acc
                avg_epoch_ft_acc += ft_acc
                avg_epoch_uni_acc += uni_acc
                avg_epoch_wellf_acc += well_f_acc
                avg_epoch_alpha_acc += alpha_equi_acc
                avg_epoch_binary_sub += bin_sub_acc

                if verbose:
                    print("Epoch {:>5}/{:<5}\tChunk: {:>5}/{: <5}\t loss: {: <20}\t acc: {: <20}\t all acc: {: <20}"
                          .format(str(e + 1), str(num_epoch), str(i + 1), str(len(batches_subform)), str(loss_val),
                                  str(acc), str(all_acc)))
                    print("Binary Subterm acc: {}".format(bin_sub_acc))
                    print("mp acc: {}".format(mp_acc))
                    print("Term vs Formula acc: {}".format(ft_acc))
                    print("Unifiability acc: {}".format(uni_acc))
                    print("Well-formedness acc: {}".format(well_f_acc))
                    print("Alpha Equivalence acc: {}".format(alpha_equi_acc))
                    print("\n")

                writer.add_summary(summary, i)

            print("\n")
            print('Average loss at epoch {:>5}/{:<5} : {}'.format(e + 1, num_epoch,
                                                                  sum_epoch_loss / num_batches))

            print("Binary Subterm acc: {}".format(avg_epoch_binary_sub / num_batches))
            print("mp acc: {}".format(avg_epoch_mp_acc / num_batches))
            print("Term vs Formula acc: {}".format(avg_epoch_ft_acc / num_batches))
            print("Unifiability acc: {}".format(avg_epoch_uni_acc / num_batches))
            print("Well-formedness acc: {}".format(avg_epoch_wellf_acc / num_batches))
            print("Alpha Equivalence acc: {}".format(avg_epoch_alpha_acc / num_batches))
            print("\n")
            sum_epoch_loss: float = 0.0
            avg_epoch_mp_acc: float = 0.0
            avg_epoch_ft_acc: float = 0.0
            avg_epoch_uni_acc: float = 0.0
            avg_epoch_wellf_acc: float = 0.0
            avg_epoch_alpha_acc: float = 0.0
            avg_epoch_binary_sub: float = 0.0

        saver.save(session, os.path.join(log_dir, 'model.ckpt'))

    writer.close()
    data_obj.dump_training_formulas()


def calculate_thought_vectors(input_formulas: List[str], batch_size: int = 16, embedding_dim=128, num_tokens=80,
                              log_dir='models/logdir', encoding_projection=False, encoding_projection_dim=256):
    """
    Encode formula strings and return their thought vectors computed by the latest saved model.

    The token vocabulary (``TrainingFormulas``) is loaded from ``<log_dir>/training_formulas.pkl``, a
    ``DeepPhiModel`` graph is built and the latest checkpoint in ``log_dir`` is restored into it.

    :param input_formulas: formula strings to embed. Formulas longer than ``max_word_len`` (as a string or
        after encoding) are skipped, so the rows of the result correspond to the *remaining* formulas, in
        their original order -- not necessarily one row per element of ``input_formulas``.
    :param batch_size: maximum number of formulas fed to the model per ``session.run`` call.
    :param embedding_dim: forwarded to ``DeepPhiModel``; presumably must match the value used for
        training, since the checkpoint is restored into this graph.
    :param num_tokens: forwarded to ``DeepPhiModel``; presumably must match training.
    :param log_dir: directory containing ``training_formulas.pkl`` and the model checkpoints.
    :param encoding_projection: forwarded to ``DeepPhiModel``; presumably must match training.
    :param encoding_projection_dim: forwarded to ``DeepPhiModel``; presumably must match training.
    :return: the per-batch outputs of ``small_enc.thought_vectors_hidden`` concatenated along axis 0.
    :raises ValueError: if ``batch_size`` is not positive, if no input formula fits into ``max_word_len``,
        or if no checkpoint exists in ``log_dir``.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    data_obj: TrainingFormulas = TrainingFormulas.load_training_formulas_strict(
        file=os.path.join(log_dir, 'training_formulas.pkl'))
    max_len = data_obj.max_word_len

    encoded_input: List[List[int]] = []
    for formula in input_formulas:
        # Cheap pre-filter on the raw string (kept from the original behaviour) ...
        if len(formula) > max_len:
            continue
        encoded = data_obj.encoding_of_string(formula)
        # ... plus a guard on the encoded length, which is what gets zero-padded to max_len below;
        # a longer encoding would make the input array ragged.
        if len(encoded) <= max_len:
            encoded_input.append(encoded)

    if not encoded_input:
        raise ValueError("none of the {} input formulas fits into max_word_len={}"
                         .format(len(input_formulas), max_len))

    # Consecutive chunks of at most batch_size formulas (the last one may be smaller). Splitting into
    # len(encoded_input) // batch_size parts would request zero parts for fewer than batch_size inputs.
    batches: List[List[List[int]]] = [encoded_input[start:start + batch_size]
                                      for start in range(0, len(encoded_input), batch_size)]

    checkpoint = tf.train.latest_checkpoint(log_dir)
    if checkpoint is None:
        raise ValueError("no model checkpoint found in {}".format(log_dir))

    g = tf.Graph()
    with g.as_default():
        # encoder
        small_enc = DeepPhiModel(embedding_dim, DeephiMode.DEFAULT,
                                 max_len=max_len,
                                 num_tokens=num_tokens,
                                 encoding_projection=encoding_projection,
                                 encoding_projection_dim=encoding_projection_dim)

        small_enc.build_model()
        saver = tf.train.Saver()

    result = []
    with tf.Session(graph=g) as session:
        # restore the last checkpoint (one could also average the weights of the last few checkpoints)
        saver.restore(session, checkpoint)

        for batch in tqdm.tqdm(batches):
            # zero-pad every encoded formula to max_len so the batch forms a rectangular array
            inputs = np.array([x + (max_len - len(x)) * [0] for x in batch])
            encode_length = np.array([len(x) for x in batch])
            feed_dict = {
                small_enc.general_encoder_input: inputs,
                small_enc.encode_length: encode_length
            }

            # this would be the most prominent vector
            thought_vecs = session.run(small_enc.thought_vectors_hidden, feed_dict=feed_dict)

            result.append(thought_vecs)

    return np.concatenate(result)
