from data_generation.formula_data_generation import *
from phi_embedding_cnn_fc import calculate_thought_vectors

from random import shuffle
from Utils import *
import argparse
from CONSTANTS import *
import phi_embedding
import phi_embedding_lstm_fc
import phi_embedding_lstm
import phi_embedding_cnn_fc_weights
import phi_embedding_cnn_fc
import numpy as np
import sys

np.set_printoptions(threshold=sys.maxsize)


def main(results, calc_thoughts):
    """Encode every formula in ``results.src`` and write the vectors to ``results.dest``.

    Args:
        results: Parsed command-line namespace. Reads ``src``, ``dest``,
            ``batch_size``, ``embedding_dim``, ``log_dir``, ``projection``,
            ``projection_dim`` and ``comment_string``.
        calc_thoughts: Encoder entry point. It is called with the list of
            formula strings plus the keyword arguments below. It is expected
            to return one vector per formula (presumably numpy arrays --
            TODO confirm for every model).

    The destination file gets a ``# <comment>`` header line, followed by one
    ``(<vector>, <formula>)`` record per formula, in input order.
    """
    fms: List[str] = file_to_string_list(results.src)

    skips = calc_thoughts(fms,
                          batch_size=results.batch_size,
                          embedding_dim=results.embedding_dim,
                          log_dir=results.log_dir,
                          encoding_projection=results.projection,
                          encoding_projection_dim=results.projection_dim)

    # The file is only written, never read back, so plain "w" suffices.
    with open(results.dest, "w") as f:
        f.write("# {}\n".format(results.comment_string))
        # zip() stops at the shorter of the two sequences.
        # NOTE(review): np.array_str wraps at the print-options line width
        # (75 by default). The module-level threshold=sys.maxsize only
        # disables "..." summarization, so a long vector can span several
        # lines. Confirm that the reader of this file tolerates that.
        for vector, formula in zip(skips, fms):
            f.write("({}, {})\n".format(np.array_str(vector), formula))


def _build_arg_parser():
    """Create the argument parser for the formula-encoding command line.

    The ``default_*`` values come from one of the star imports at the top of
    the file (presumably ``CONSTANTS`` -- confirm).
    """
    parser = argparse.ArgumentParser(description='DeepPhi - Formula Encoding')
    parser.add_argument('-model-num', dest='model_num', action="store", type=int, required=True,
                        help='Encoder to use: 1=cnn_fc, 2=cnn_fc_weights, 3=lstm, 4=lstm_fc; '
                             'any other value selects phi_embedding.')
    parser.add_argument('-batch-size', dest='batch_size', action="store", type=int,
                        default=default_batch_size)
    parser.add_argument('-embedding-dim', dest='embedding_dim', action="store", type=int,
                        default=default_embedding_dim)
    parser.add_argument('-projection', dest='projection', action="store_true", default=False,
                        help='Apply the encoding projection (forwarded as encoding_projection).')
    parser.add_argument('-projection-dim', dest='projection_dim', action="store", type=int,
                        default=default_projection_dim)

    # NOTE(review): -num-multilabels and -max-len are parsed, but main() does
    # not forward them to calculate_thought_vectors. Confirm this is intended.
    parser.add_argument('-num-multilabels', dest='vocab_size', action="store", type=int, default=default_vocab_size,
                        help='Number of whole formulas saved in dictionary (=labels in multilabel classification)')
    parser.add_argument('-max-len', dest='max_len', action="store", type=int, default=default_max_len,
                        help='Max length of input formulas.')
    parser.add_argument('-log-dir', dest='log_dir', action="store", type=str, required=True,
                        help='Directory for logging.')

    parser.add_argument('-src', dest='src', action="store", type=str, required=True,
                        help='Source File.')
    parser.add_argument('-dest', dest='dest', action="store", type=str, required=True,
                        help='Destination file.')
    parser.add_argument('-comment', dest='comment_string', action="store", type=str, required=False, default="",
                        help='Comment in top of output file')
    return parser


def _select_calc_thoughts(model_num):
    """Return the ``calculate_thought_vectors`` function for *model_num*.

    The mapping is 1 -> phi_embedding_cnn_fc, 2 -> phi_embedding_cnn_fc_weights,
    3 -> phi_embedding_lstm and 4 -> phi_embedding_lstm_fc. Every other value
    falls back to phi_embedding.
    """
    encoder_modules = {
        1: phi_embedding_cnn_fc,
        2: phi_embedding_cnn_fc_weights,
        3: phi_embedding_lstm,
        4: phi_embedding_lstm_fc,
    }
    # The table maps to modules, and the attribute is looked up last. This way
    # only the selected module's function is ever accessed.
    return encoder_modules.get(model_num, phi_embedding).calculate_thought_vectors


if __name__ == '__main__':
    results = _build_arg_parser().parse_args()
    main(results, _select_calc_thoughts(results.model_num))
