
import numpy as np
from random import shuffle, sample

from data_generation.formula_data_generation import TrainingFormulas, custom_tptp_file_to_string_list
from typing import Tuple
from Utils import *
import os


def sub_formula_pre_processing(data_obj, files: List[str]):
    """Build (padded formula encoding, multi-hot label vector) training pairs, one per file.

    For each file, ``data_obj.generate_int_training_pairs_subterm([file])`` yields
    ``(key, label)`` pairs. Every label of the file is set to 1 in a vector of length
    ``data_obj.max_vocab_size``; the key of the first pair is encoded with
    ``data_obj.word_encoding_of_key`` and right-padded with 0 up to ``data_obj.max_word_len``.

    Files that yield no pairs, or whose encoding is not shorter than ``max_word_len``,
    are skipped.

    :param data_obj: object providing the pair generation / encoding helpers used above.
    :param files: paths of the files to process.
    :return: list of ``(padded_encoding, label_vec)`` tuples.
    """
    training_list: List[Tuple] = []
    for file in files:
        pairs_list: List[Tuple] = data_obj.generate_int_training_pairs_subterm([file])
        if not pairs_list:
            # No pairs means no key to encode (pairs_list[0] would raise IndexError).
            continue
        label_vec = [0] * data_obj.max_vocab_size
        for _, label in pairs_list:
            label_vec[label] = 1
        # Presumably every pair of a file shares the same key (the file's formula) — TODO confirm.
        encoding = data_obj.word_encoding_of_key(pairs_list[0][0])
        # NOTE(review): strict '<' also drops encodings of exactly max_word_len, which would
        # fit without padding — confirm this is intended.
        if len(encoding) < data_obj.max_word_len:
            padded = encoding + [0] * (data_obj.max_word_len - len(encoding))
            training_list.append((padded, label_vec))
    return training_list


def sub_formula_binary_pre_processing_with_encoding(data_obj: TrainingFormulas, files: List[str], num_negatives=3):
    """Build encoded positive and sampled negative pairs for binary sub-formula classification.

    For every file, ``data_obj.string_list_of_file(file)`` returns a list whose first entry is
    used as the key formula ``x``; every entry ``z`` of that list (``x`` included) yields a
    positive pair ``(x, z)``. Files returning an empty list are skipped.

    For every positive pair ``(x, y)``, ``num_negatives`` candidates (capped at the pool size)
    are drawn without replacement from the pool of all strings of all files. A candidate ``z``
    becomes a negative pair ``(x, z)`` only if it is not listed for ``x`` itself and, when
    ``y`` is also a key, not listed for ``y`` either.

    :param data_obj: provides ``string_list_of_file`` and ``encoding_of_string``.
    :param files: files to read.
    :param num_negatives: number of candidates drawn per positive pair.
    :return: ``(positives, negatives)``; each element is
        ``(encoding_of_string(x), encoding_of_string(z))``.
    """
    concatenated = {}
    all_: List[str] = []
    for file in files:
        encoded_list: List[str] = data_obj.string_list_of_file(file)
        if not encoded_list:
            continue  # no key formula to anchor pairs on
        all_.extend(encoded_list)
        concatenated[encoded_list[0]] = encoded_list

    # Membership sets built once, so filtering each candidate is O(1).
    listed = {key: set(values) for key, values in concatenated.items()}

    positive_examples: List[Tuple] = [(x, z) for x, subs in concatenated.items() for z in subs]

    # random.sample raises ValueError if asked for more items than the pool holds.
    draw = min(num_negatives, len(all_))
    negative_examples: List[Tuple] = []
    for x, y in positive_examples:
        # Exclude x's own listed sub-formulas (this covers x itself, i.e. reflexivity) ...
        own = listed[x]
        # ... and, when y is itself a key, y's listed sub-formulas, which are presumably
        # sub-formulas of x as well.
        via_y = listed.get(y, frozenset())
        negative_examples.extend(
            (x, z) for z in sample(all_, draw) if z not in own and z not in via_y
        )

    def _encode(pairs):
        # Encode both sides of every pair with the data object's string encoder.
        return [(data_obj.encoding_of_string(a), data_obj.encoding_of_string(b)) for a, b in pairs]

    return _encode(positive_examples), _encode(negative_examples)


def sub_term_pre_eval_processing(data_obj: TrainingFormulas, files: List[str]):
    """
    Build (padded formula encoding, multi-hot sub-term label vector) pairs for evaluation.

    Due to (bad) design decisions, each file's statements are read directly with
    ``custom_tptp_file_to_string_list``. The first statement is the formula that gets encoded
    (``data_obj.encoding_of_string``) and right-padded with 0 up to ``data_obj.max_word_len``.
    Every statement, the formula included, is mapped through ``data_obj.key_of_word`` and
    marked with 1 in a vector of length ``data_obj.max_vocab_size``.

    Files without statements, or whose formula is longer than ``max_word_len`` (as a string
    or once encoded), are skipped. As a result, every returned encoding has exactly
    ``max_word_len`` entries.
    """
    training_list: List[Tuple[List[int], List[int]]] = []
    for file in files:
        stmts = custom_tptp_file_to_string_list(file)
        if not stmts:
            continue  # no formula to encode (stmts[0] would raise IndexError)
        if len(stmts[0]) > data_obj.max_word_len:  # maybe handle this better with <unk> or something
            continue
        label_vec = data_obj.max_vocab_size * [0]
        for stmt in stmts:
            label_vec[data_obj.key_of_word(stmt)] = 1
        t = data_obj.encoding_of_string(stmts[0])
        if len(t) > data_obj.max_word_len:
            # Padding cannot shorten an over-long encoding; keeping it would produce a row
            # wider than max_word_len.
            continue
        n_t = t + [0] * (data_obj.max_word_len - len(t))
        training_list.append((n_t, label_vec))
    return training_list


# Borrowed from elsewhere. TODO: change this to retrieve ALL labels whose logits are
# set to 1 after rounding, instead of a fixed top-k (still useful as is).
def get_label_using_logits(logits, vocabulary_index2word_label, top_number=5):
    """Return the labels of the ``top_number`` highest logits, highest first.

    :param logits: scores to rank (presumably 1-D — TODO confirm).
    :param vocabulary_index2word_label: mapping / sequence from index to label.
    :param top_number: how many labels to return. Note that ``0`` selects every index,
        because the slice ``[-0:]`` is the whole array.
    :return: list of labels.
    """
    ascending_order = np.argsort(logits)
    best_first = ascending_order[-top_number:][::-1]
    return [vocabulary_index2word_label[idx] for idx in best_first]


def get_label_indices(rounded_logits):
    """Return the 0-based positions of ``rounded_logits`` whose value equals 1.

    Positions come from ``enumerate``, so they start at 0, like ordinary list indexing.
    """
    hits = []
    for position, value in enumerate(rounded_logits):
        if value == 1:
            hits.append(position)
    return hits


def get_label_indices_np(array: np.ndarray):
    """Return the first-axis indices of the elements of ``array`` that equal 1.

    For a 1-D array these are simply the positions holding 1.
    """
    is_one = array == 1
    return np.nonzero(is_one)[0]


def modus_ponens_pre_processing(data_obj, files: List[str]):
    """Split ``files`` into modus-ponens positives and negatives by file name, then encode them.

    A file is a positive sample when its base name contains ``'_true_'`` and a negative sample
    when its base name contains ``'_false_'``. Each group is encoded with
    ``data_obj.binary_data_set_to_encoded_pairs``.

    :return: ``(positive_pairs, negative_pairs)``.
    """
    def _named_with(tag):
        # Match on the base name only, so directory names never count.
        return [path for path in files if tag in os.path.basename(path)]

    positives = data_obj.binary_data_set_to_encoded_pairs(_named_with('_true_'))
    negatives = data_obj.binary_data_set_to_encoded_pairs(_named_with('_false_'))
    return positives, negatives


def formula_term_class_prep(data_obj: TrainingFormulas, files: List[str]):
    """Load formulas and terms for formula-vs-term classification and encode both.

    Formulas are extracted with ``data_obj.files_based_data_set_extraction`` from the files
    whose base name contains ``'_formulas'``. Terms are read one per line (whitespace-stripped)
    from the files whose base name contains ``'_terms'``. Both are encoded with
    ``data_obj.encoding_of_string``.

    :return: ``(encoded_formulas, encoded_terms)``.
    """
    formula_files = [s for s in files if '_formulas' in os.path.basename(s)]
    fms = data_obj.files_based_data_set_extraction(formula_files)

    # Match on the base name, like the formula filter above. Testing the full path would also
    # treat a formula file as a terms file when its directory name contains '_terms'.
    term_files = [s for s in files if '_terms' in os.path.basename(s)]
    terms = []
    for path in term_files:
        with open(path, 'r', encoding='utf-8') as fp:
            terms.extend(line.strip() for line in fp)
    return [data_obj.encoding_of_string(f) for f in fms], [data_obj.encoding_of_string(t) for t in terms]


def unifiability_classifier_data_generation(data_obj: TrainingFormulas, files: List[str]) \
        -> Tuple[List[Tuple[List[int], List[int]]], List[Tuple[List[int], List[int]]]]:
    """Select unifiable / not-unifiable files by path and encode each group.

    Negative (not unifiable): the path contains ``'_not_unifiable.p'``.
    Positive (unifiable): the path is not negative and contains ``'_unifiable.p'`` or
    ``'is_unifiable'``. Matching is done on the full path, so directory names count too.
    Each group is encoded with ``data_obj.binary_set_free_formulas_encoding``.

    :return: ``(unifiable, not_unifiable)``.
    """
    def _is_negative(path: str) -> bool:
        return '_not_unifiable.p' in path

    def _is_positive(path: str) -> bool:
        if _is_negative(path):
            return False
        return '_unifiable.p' in path or 'is_unifiable' in path

    unifiable = data_obj.binary_set_free_formulas_encoding([p for p in files if _is_positive(p)])
    not_unifiable = data_obj.binary_set_free_formulas_encoding([p for p in files if _is_negative(p)])
    return unifiable, not_unifiable


def well_formedness_pre_processing_with_encoding(data_obj: TrainingFormulas, files: List[str]) \
        -> Tuple[List[List[int]], List[List[int]]]:
    """Collect encoded well-formed (positive) and not-well-formed (negative) examples.

    Positives: each file whose base name contains ``'well_formed'`` is passed to
    ``data_obj.encoded_string_list_of_file``, and the per-file results are concatenated in
    file order.
    Negatives: all files whose path contains ``'not_wellf'`` are passed together to
    ``data_obj.file_with_free_formulas_to_encoding``.

    :return: ``(positive_encodings, negative_encodings)``.
    """
    positive_examples_enc: List[List[int]] = [
        encoding
        for path in files
        if 'well_formed' in os.path.basename(path)
        for encoding in data_obj.encoded_string_list_of_file(path)
    ]

    # NOTE(review): unlike the positive filter, this one matches the full path, not the base name.
    negative_files = [path for path in files if 'not_wellf' in path]
    negative_examples_enc: List[List[int]] = data_obj.file_with_free_formulas_to_encoding(negative_files)
    return positive_examples_enc, negative_examples_enc


def alpha_equivalence_pre_processing_with_encoding(data_obj: TrainingFormulas, files: List[str]) \
        -> Tuple[List[Tuple[List[int], List[int]]], List[Tuple[List[int], List[int]]]]:
    """Encode alpha-equivalent (positive) and non-equivalent (negative) formula pairs.

    Files are selected by base name: ``'alpha_equivalent_pos'`` marks positives and
    ``'alpha_not_equiv_neg'`` marks negatives. Each group is encoded with
    ``data_obj.binary_data_set_to_encoded_pairs``.

    :return: ``(positive_pairs, negative_pairs)``.
    """
    def _select(marker: str) -> List[str]:
        return [path for path in files if marker in os.path.basename(path)]

    return (data_obj.binary_data_set_to_encoded_pairs(_select('alpha_equivalent_pos')),
            data_obj.binary_data_set_to_encoded_pairs(_select('alpha_not_equiv_neg')))
