from Utils import *
import os
import random
import argparse
import numpy as np
import phi_embedding
import phi_embedding_lstm_fc
import phi_embedding_lstm
import phi_embedding_cnn_fc_weights
import phi_embedding_cnn_fc

from CONSTANTS import *

# Orchestration only: data splitting plus train/evaluate dispatch; the models live in the phi_embedding* modules.


# Number of distinct input tokens, passed as ``num_tokens`` to every model's
# train()/evaluation(). Presumably the size of the formula alphabet, so it may
# need to change for a different data set.
NUM_TOKENS: int = 80


def partition_data(source: str, dest: str = "../evaluation", nth=10):
    """
    :param source:
    :param dest:
    :param nth: move 1/th from source to dest
    :return:
    """
    if not os.path.exists(dest):
        os.makedirs(dest, 0o777)

    subfm_training = os.path.join(source, DataDirNames.SUBFORMULAS.value)
    mp_training = os.path.join(source, DataDirNames.MODUS_PONENS.value)
    term_form_training = os.path.join(source, DataDirNames.TERM_FORMULA.value)
    unifiability_training = os.path.join(source, DataDirNames.UNIFIABILITY.value)
    well_formed_training = os.path.join(source, DataDirNames.WELL_FORMEDNESS.value)
    alpha_training = os.path.join(source, DataDirNames.ALPHA_EQUIVALENCE.value)

    subfm_eval = os.path.join(dest, DataDirNames.SUBFORMULAS.value)
    mp_eval = os.path.join(dest, DataDirNames.MODUS_PONENS.value)
    term_form_eval = os.path.join(dest, DataDirNames.TERM_FORMULA.value)
    unifiability_eval = os.path.join(dest, DataDirNames.UNIFIABILITY.value)
    well_formed_eval = os.path.join(dest, DataDirNames.WELL_FORMEDNESS.value)
    alpha_eval = os.path.join(dest, DataDirNames.ALPHA_EQUIVALENCE.value)

    if not os.path.exists(subfm_eval):
        os.makedirs(subfm_eval)
    if not os.path.exists(mp_eval):
        os.makedirs(mp_eval)
    if not os.path.exists(term_form_eval):
        os.makedirs(term_form_eval)
    if not os.path.exists(unifiability_eval):
        os.makedirs(unifiability_eval)
    if not os.path.exists(well_formed_eval):
        os.makedirs(well_formed_eval)
    if not os.path.exists(alpha_eval):
        os.makedirs(alpha_eval)

    _tmp = list_files_from_directory(subfm_training)
    eval_files = random.sample(_tmp, int(float(len(_tmp)) / nth))
    for file in eval_files:
        os.rename(file, os.path.join(subfm_eval, os.path.basename(file)))

    _tmp = list_files_from_directory(mp_training)
    eval_files = random.sample(_tmp, int(float(len(_tmp)) / nth))
    for file in eval_files:
        os.rename(file, os.path.join(mp_eval, os.path.basename(file)))

    _tmp = list_files_from_directory(unifiability_training)
    eval_files = random.sample(_tmp, int(float(len(_tmp)) / nth))
    for file in eval_files:
        os.rename(file, os.path.join(unifiability_eval, os.path.basename(file)))

    _tmp = list_files_from_directory(alpha_training)
    eval_files = random.sample(_tmp, int(float(len(_tmp)) / nth))
    for file in eval_files:
        os.rename(file, os.path.join(alpha_eval, os.path.basename(file)))

    # ##### Term/Forms
    _tmp = list_files_from_directory(term_form_training)

    formulas_file = [s for s in _tmp if '_formulas' in os.path.basename(s)][0]
    formula_lines = file_to_string_list(formulas_file)
    random.shuffle(formula_lines)
    training_fms = formula_lines[int(float(len(formula_lines)) / nth):]
    eval_fms = formula_lines[:int(float(len(formula_lines)) / nth)]
    with open(os.path.join(term_form_training, os.path.basename(formulas_file)), "w+") as f:
        for tf in training_fms:
            f.write(tf.rstrip()+"\n")
    with open(os.path.join(term_form_eval, os.path.basename(formulas_file)), "w+") as f:
        for ef in eval_fms:
            f.write(ef.rstrip()+"\n")

    terms_file = [s for s in _tmp if '_terms' in s][0]
    terms_lines = file_to_string_list(terms_file)
    random.shuffle(terms_lines)
    training_terms = terms_lines[int(float(len(terms_lines)) / nth):]
    eval_terms = terms_lines[:int(float(len(terms_lines)) / nth)]
    with open(os.path.join(term_form_training, os.path.basename(terms_file)), "w+") as f:
        for tf in training_terms:
            f.write(tf.rstrip()+"\n")
    with open(os.path.join(term_form_eval, os.path.basename(terms_file)), "w+") as f:
        for ef in eval_terms:
            f.write(ef.rstrip()+"\n")

    # ##### Well formedness
    _tmp = list_files_from_directory(well_formed_training)

    try:
        wf_file = [s for s in _tmp if '_well_formed' in os.path.basename(s)]
    except IndexError:
        print("Error partitioning Well-formed data")
        return
    wf_file = wf_file[0]
    wf_lines = file_to_string_list(wf_file)
    random.shuffle(wf_lines)
    training_wf = wf_lines[int(float(len(wf_lines)) / nth):]
    eval_wf = wf_lines[:int(float(len(wf_lines)) / nth)]
    with open(os.path.join(well_formed_training, os.path.basename(wf_file)), "w+") as f:
        for tf_ in training_wf:
            f.write(tf_.rstrip()+"\n")
    with open(os.path.join(well_formed_eval, os.path.basename(wf_file)), "w+") as f:
        for ef in eval_wf:
            f.write(ef.rstrip()+"\n")

    try:
        not_wf_file = [s for s in _tmp if '_not_wellf' in s][0]
    except IndexError:
        return
    not_wf_lines = file_to_string_list(not_wf_file)
    random.shuffle(not_wf_lines)
    training_not_wf = not_wf_lines[int(float(len(not_wf_lines)) / nth):]
    eval_not_wf = not_wf_lines[:int(float(len(not_wf_lines)) / nth)]
    with open(os.path.join(well_formed_training, os.path.basename(not_wf_file)), "w+") as f:
        for tf_ in training_not_wf:
            f.write(tf_.rstrip()+"\n")
    with open(os.path.join(well_formed_eval, os.path.basename(not_wf_file)), "w+") as f:
        for ef in eval_not_wf:
            f.write(ef.rstrip()+"\n")


def union_data_into(source, destination):
    if not os.path.exists(source):
        return
    if not os.path.exists(destination):
        print("No destination")
        return

    # training is source
    subfm_training = os.path.join(source, DataDirNames.SUBFORMULAS.value)
    mp_training = os.path.join(source, DataDirNames.MODUS_PONENS.value)
    term_form_training = os.path.join(source, DataDirNames.TERM_FORMULA.value)
    unifiability_training = os.path.join(source, DataDirNames.UNIFIABILITY.value)
    well_formed_training = os.path.join(source, DataDirNames.WELL_FORMEDNESS.value)
    alpha_training = os.path.join(source, DataDirNames.ALPHA_EQUIVALENCE.value)

    # eval is destination
    subfm_eval = os.path.join(destination, DataDirNames.SUBFORMULAS.value)
    mp_eval = os.path.join(destination, DataDirNames.MODUS_PONENS.value)
    term_form_eval = os.path.join(destination, DataDirNames.TERM_FORMULA.value)
    unifiability_eval = os.path.join(destination, DataDirNames.UNIFIABILITY.value)
    well_formed_eval = os.path.join(destination, DataDirNames.WELL_FORMEDNESS.value)
    alpha_eval = os.path.join(destination, DataDirNames.ALPHA_EQUIVALENCE.value)

    if not os.path.exists(subfm_training):
        return
    if not os.path.exists(mp_training):
        return
    if not os.path.exists(unifiability_training):
        return
    if not os.path.exists(alpha_training):
        return

    move_files(subfm_training, subfm_eval)
    move_files(mp_training, mp_eval)
    move_files(unifiability_training, unifiability_eval)
    move_files(alpha_training, alpha_eval)

    # Move Terms/Formulas
    _tmp = list_files_from_directory(term_form_training)
    if len(_tmp) == 0:
        return

    formulas_file = [s for s in _tmp if '_formulas' in s]
    if len(formulas_file) > 0:
        formulas_file = formulas_file[0]
        with open(os.path.join(term_form_training, os.path.basename(formulas_file)), "r+") as srce:
            with open(os.path.join(term_form_eval, os.path.basename(formulas_file)), "a") as dest:
                for line in srce:
                    dest.write(line.rstrip()+"\n")
        os.remove(formulas_file)

    terms_file = [s for s in _tmp if '_terms' in s]
    if len(terms_file) > 0:
        terms_file = terms_file[0]
        with open(os.path.join(term_form_training, os.path.basename(terms_file)), "r+") as src:
            with open(os.path.join(term_form_eval, os.path.basename(terms_file)), "a") as dest:
                for line in src:
                    dest.write(line.rstrip()+"\n")
        os.remove(terms_file)

    # Move Well formed
    _tmp = list_files_from_directory(well_formed_training)
    if len(_tmp) == 0:
        return

    wf_files = [s for s in _tmp if '_well_formed' in s]
    if len(wf_files) > 0:
        wf_files = wf_files[0]
        with open(os.path.join(well_formed_training, os.path.basename(wf_files)), "r+") as srce:
            with open(os.path.join(well_formed_eval, os.path.basename(wf_files)), "a") as dest:
                for line in srce:
                    dest.write(line.rstrip()+"\n")
        os.remove(wf_files)

    not_wf_file = [s for s in _tmp if '_not_wellf' in s]
    if len(not_wf_file) > 0:
        not_wf_file = not_wf_file[0]
        with open(os.path.join(well_formed_training, os.path.basename(not_wf_file)), "r+") as src:
            with open(os.path.join(well_formed_eval, os.path.basename(not_wf_file)), "a") as dest:
                for line in src:
                    dest.write(line.rstrip()+"\n")
        os.remove(not_wf_file)


def start(train, evaluate, separate_data, put_back_data, num_epoch, batch_size, embedding_dim, max_len,
          learning_rate, decay_rate, decay_steps, vocab_size, log_dir, training_data, evaluation_data, test_data,
          verbose, max_eval_examples, embed, embedding_input, embedding_output, model_num, projection, projection_dim,
          split=10):
    """
    Run the DeepPhi pipeline: optional embedding, data split, training, evaluation, merge-back.

    :param train: call the selected model's ``train`` on ``training_data``
    :param evaluate: call the selected model's ``evaluation`` on ``evaluation_data``
    :param separate_data: first merge ``evaluation_data``/``test_data`` back into
        ``training_data``, then move 1/``split`` of it to ``evaluation_data``
    :param put_back_data: at the end, merge evaluation and test data back into ``training_data``
    :param embed: compute vectors for the lines of ``embedding_input`` with
        ``phi_embedding.calculate_thought_vectors`` and ``np.save`` them to ``embedding_output``
    :param model_num: 1 = CNN FC, 2 = CNN FC Weights, 3 = LSTM, 4 = LSTM FC;
        any other value selects the plain CNN (``phi_embedding``)
    :param split: the ``nth`` passed to ``partition_data``
    The remaining parameters are hyper-parameters forwarded unchanged to the model module.
    :return: None
    """
    print("\n\n#############    Starting DeepPhi...     ###########\n")

    if embed:
        strs = file_to_string_list(embedding_input)
        res = phi_embedding.calculate_thought_vectors(strs, batch_size, embedding_dim=embedding_dim, log_dir=log_dir)
        np.save(embedding_output, res)

    print("####### Parameters #######")
    for name, value in (("epochs", num_epoch), ("batch_size", batch_size), ("embedding_dim", embedding_dim),
                        ("max_len", max_len), ("vocab_size", vocab_size), ("learning_rate", learning_rate),
                        ("decay_rate", decay_rate), ("decay_steps", decay_steps), ("log_dir", log_dir),
                        ("training_data", training_data), ("evaluation_data", evaluation_data),
                        ("test_data", test_data)):
        print("{}: {}".format(name, value))
    print("\n")

    if separate_data:
        # ### DATA Training/Evaluation partition
        print("Splitting Data...")
        # for cleaning potentially previously separated data
        union_data_into(evaluation_data, training_data)
        union_data_into(test_data, training_data)

        print("Generating evaluation data...")
        partition_data(training_data, dest=evaluation_data, nth=split)
        # NOTE: test-data partitioning is currently disabled; test_data is only merged
        # back into training_data (above and in the put_back_data step).

    # model_num -> (model module, display label); anything unknown falls back to the plain CNN.
    models = {
        1: (phi_embedding_cnn_fc, "CNN FC(1)"),
        2: (phi_embedding_cnn_fc_weights, "CNN FC Weights(2)"),
        3: (phi_embedding_lstm, "LSTM(3)"),
        4: (phi_embedding_lstm_fc, "LSTM FC(4)"),
    }
    model, label = models.get(model_num, (phi_embedding, "CNN(0)"))

    # Keyword arguments shared by train() and evaluation() of every model module.
    common = dict(embedding_dim=embedding_dim, num_tokens=NUM_TOKENS, max_len=max_len,
                  learning_rate=learning_rate, decay_rate=decay_rate, decay_steps=decay_steps,
                  vocab_size=vocab_size, log_dir=log_dir, verbose=verbose,
                  encoding_projection=projection, projection_dim=projection_dim)

    # ### TRAINING
    if train:
        print("\nTraining {}...\n".format(label))
        model.train(training_data, num_epoch=num_epoch, batch_size=batch_size, **common)

    # ### EVALUATION
    if evaluate:
        print("\nEvaluating {}...\n".format(label))
        model.evaluation(evaluation_data, max_examples_per_prop=max_eval_examples, **common)

    if put_back_data:
        # ### PUT Evaluation data back into dataset
        print("Putting evaluation data back...")
        union_data_into(evaluation_data, training_data)
        union_data_into(test_data, training_data)


def arg_2_bool(v):
    """
    Parse a command-line boolean such as ``yes``/``no``, ``true``/``false`` or ``1``/``0``.

    Matching is case-insensitive. Values that are already ``bool`` are returned
    unchanged, so the function is also safe to apply to non-string defaults.

    :param v: the raw argument value
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if ``v`` is not a recognised boolean spelling
    """
    if isinstance(v, bool):
        return v
    # Compared after lower-casing, so only lower-case spellings are needed here.
    normalized = v.lower()
    if normalized in {'yes', 'true', 't', 'y', '1'}:
        return True
    if normalized in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    """Parse the DeepPhi command line and run :func:`start` with the parsed options."""
    print("STARTING")

    parser = argparse.ArgumentParser(description='DeepPhi - Formula Encoding')

    parser.add_argument('-train', action="store_true", default=False, help='Starts training.')
    parser.add_argument('-evaluate', action="store_true", default=False, help='Starts evaluation.')
    parser.add_argument('-separate-data', dest='separate_data', type=arg_2_bool, default=True,
                        help='If data should be separated into training/evaluation data.')
    parser.add_argument('-merge-data', dest='merge_data', type=arg_2_bool, default=True,
                        help='If separated data should be merged again')
    parser.add_argument('-split', help='Move 1/N of the training data to the evaluation data.',
                        action="store", type=int, default=default_train_ev_split)

    parser.add_argument('-verbose', dest='verbose', action="store_true", default=False,
                        help='Verbose output')

    parser.add_argument('-max-eval-examples', dest='max_eval_examples', action="store", type=int,
                        default=default_eval_examples,
                        help="Determines the maximum number of examples to choose from evaluation data.")

    parser.add_argument('-epochs', action="store", type=int, default=default_num_epoch)
    parser.add_argument('-batch-size', dest='batch_size', action="store", type=int,
                        default=default_batch_size)
    parser.add_argument('-embedding-dim', dest='embedding_dim', action="store", type=int,
                        default=default_embedding_dim)
    parser.add_argument('-projection', action="store_true", default=False,
                        help='Enables the encoding projection (size set by -projection-dim).')
    parser.add_argument('-projection-dim', dest='projection_dim', action="store", type=int,
                        default=default_projection_dim)

    # (-num-lstm-layers was removed: it was deprecated and had no effect.)

    parser.add_argument('-num-multilabels', dest='vocab_size', action="store", type=int, default=default_vocab_size,
                        help='Number of whole formulas saved in dictionary (=labels in multilabel classification)')
    parser.add_argument('-max-len', dest='max_len', action="store", type=int, default=default_max_len,
                        help='Max length of input formulas.')
    parser.add_argument('-log-dir', dest='log_dir', action="store", type=str, default=default_log_dir,
                        help='Directory for logging.')
    parser.add_argument('-training-data', dest='training_data', action="store", type=str,
                        default=default_training_data, help='Where training data is located.')
    parser.add_argument('-evaluation-data', dest='evaluation_data', action="store", type=str,
                        default=default_evaluation_data, help='Where evaluation data is/will be located.')
    parser.add_argument('-test-data', dest='test_data', action="store", type=str, default=default_test_data,
                        help='Where test data is/will be located.')

    parser.add_argument('-embed', action="store_true", default=False, help='Embeds strings.')
    parser.add_argument('-embedding-input', dest='embedding_input', action="store", type=str,
                        help='Where data is for embedding.')
    parser.add_argument('-embedding-output', dest='embedding_output', action="store", type=str,
                        help='Where embeddings are saved to.')

    parser.add_argument('-model-num', dest="model_num", action="store", type=int, default=0,
                        help='Model to use: 1 = CNN FC, 2 = CNN FC Weights, 3 = LSTM, 4 = LSTM FC, '
                             'any other value = CNN.')

    results = parser.parse_args()

    # -embed reads -embedding-input and writes -embedding-output; report a usage error
    # up front instead of failing later on a None path.
    if results.embed and (results.embedding_input is None or results.embedding_output is None):
        parser.error('-embed requires both -embedding-input and -embedding-output')

    start(train=results.train, evaluate=results.evaluate, separate_data=results.separate_data,
          put_back_data=results.merge_data, num_epoch=results.epochs, batch_size=results.batch_size,
          embedding_dim=results.embedding_dim, vocab_size=results.vocab_size, max_len=results.max_len,
          log_dir=results.log_dir, training_data=results.training_data, test_data=results.test_data,
          evaluation_data=results.evaluation_data, learning_rate=default_learning_rate,
          decay_rate=default_decay_rate, decay_steps=default_decay_steps, verbose=results.verbose,
          max_eval_examples=results.max_eval_examples, embed=results.embed,
          embedding_input=results.embedding_input, embedding_output=results.embedding_output,
          model_num=results.model_num, projection=results.projection, projection_dim=results.projection_dim,
          split=results.split)

    print("Done!")


# #######  START ############
# Simplest invocation:
#   python3 DeepPhi.py -train -evaluate -training-data ../small_training_data
# Known-working invocation:
#   python3 DeepPhi.py -train -evaluate -batch-size 2 -embedding-dim 128 -max-len 256 \
#       -num-multilabels 1024 -training-data ../small_training_data
if __name__ == "__main__":
    main()
