microsoft / ContextualSP

Open-source code for multiple papers from the Microsoft Research Asia DKI group

Reproducing LEMON

ardauzunoglu opened this issue

Hello,

Firstly, thanks for the great work! After reading "LEMON: Language-Based Environment Manipulation via Execution-Guided Pre-training", I wanted to reproduce the results on Propara. However, I obtained a very low accuracy score.

Python version: 3.9.0
Fairseq version: 0.12.2

Here are the steps I followed:

1 - Cloning the repository.
2 - Downloading the data/BART models.
3 - Preprocessing propara for both the pretraining and finetuning.
4 - Pretraining with BART-large.
5 - Finetuning the pretrained BART-large model.

For preprocessing, I used the preprocess_pretrain.sh and preprocess_finetune.sh files. For pretraining and finetuning, I used the pretrain.sh and finetune.sh files without changing any parameters. These steps led to the following performance:

Correct / Total : 19 / 368, Denotation Accuracy : 0.052
path: bart_large_finetuned/checkpoint_best.pt, stage: valid, 1utts: 0.017, 3utts: 0.018, 5utts: 0.0
path: bart_large_finetuned/checkpoint_best.pt, stage: test, 1utts: 0.041, 3utts: 0.068, 5utts: 0.068

I would really appreciate your help for reproducing the results.
Thanks in advance.
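The denotation accuracy printed above is the fraction of examples whose predicted final state matches the gold one (19 / 368 ≈ 0.052). A minimal sketch, assuming exact string match after whitespace normalization; `denotation_accuracy` is an illustrative name, and the repository's own scorer may normalize states differently:

```python
def denotation_accuracy(predicted_states, gold_states):
    """Fraction of examples whose predicted state string equals the gold one.

    Assumption: states are compared as strings after collapsing whitespace.
    """
    if len(predicted_states) != len(gold_states):
        raise ValueError("prediction and gold lists differ in length")
    if not gold_states:
        return 0.0
    correct = sum(
        " ".join(p.split()) == " ".join(g.split())
        for p, g in zip(predicted_states, gold_states)
    )
    return correct / len(gold_states)


# Toy data: 19 exact matches out of 368 examples
preds = ["state : - | leaf"] * 19 + ["state : - | -"] * 349
golds = ["state : - | leaf"] * 368
print(round(denotation_accuracy(preds, golds), 3))  # 0.052
```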

commented

@ardauzunoglu Thanks for reaching out! I think the problem may be that the pre-trained bart-large weights are not being loaded correctly, since your bart-large results look weird.

You may check with the following steps:

  • Have you downloaded the bart-large checkpoint folder to a local folder?
  • Did you specify the folder path in the pre-training and fine-tuning scripts (i.e., replace BART_MODEL_PATH with the path to the model weights, such as /path/to/model.pt)?

I downloaded BART-large from here. I got the model.pt, dict.txt, and NOTE files after unzipping.

Then, I used the following commands to get the encoder.json and vocab.bpe files:
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'

This is the command I used:
python3 lemon/run_model_pretrain.py train \
--dataset-dir lemon_data/pretraining_corpus/propara_catall/bin_large \
--exp-dir bart_large_pretrained \
--model-path bart.large \
--model-arch bart_large \
--total-num-update 10000 \
--max-tokens 1800 \
--gradient-accumulation 8 \
--warmup-steps 1500

where bart.large is the folder that includes the aforementioned files.

Might the problem be that I did not specify bart.large/model.pt?
Thank you so much for your quick response, really appreciate it.

commented

@ardauzunoglu Yes, the problem is likely the path. You should set --model-path to bart.large/model.pt. Please try again; it should work well.

Sorry for the misleading information. I will make this clearer later.
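Since the issue turned out to be passing the folder rather than the checkpoint file, a small pre-flight check can catch it before training starts. This is a hypothetical helper, not part of the repository; it assumes the folder layout described above (model.pt and dict.txt from the archive, plus the downloaded encoder.json and vocab.bpe):

```python
from pathlib import Path

# Files expected in the unzipped bart.large folder (assumption based on this thread)
EXPECTED_FILES = ("model.pt", "dict.txt", "encoder.json", "vocab.bpe")


def resolve_model_path(path):
    """Return the checkpoint file to pass as --model-path.

    Accepts either the checkpoint itself or the folder containing it and
    raises FileNotFoundError if any expected file is missing from the folder.
    """
    p = Path(path)
    folder = p if p.is_dir() else p.parent
    missing = [name for name in EXPECTED_FILES if not (folder / name).is_file()]
    if missing:
        raise FileNotFoundError(f"missing in {folder}: {', '.join(missing)}")
    # A folder is turned into <folder>/model.pt; a file path is returned unchanged
    return str(folder / "model.pt") if p.is_dir() else str(p)
```

For example, `resolve_model_path("bart.large")` would return `bart.large/model.pt` when all four files are present, which is the value the maintainer recommends for `--model-path`.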

Due to computational limits, I couldn't pretrain and finetune BART-large. Instead, I pretrained and finetuned BART-base on Propara for both stages of training. At the end of fine-tuning, the checkpoints are evaluated with denotation accuracy. Since the paper indicates that denotation accuracy is used for Alchemy, Scene, and Tangrams, how can I evaluate the replicated model on Propara?

Thanks for your help.
(By the way, the highest denotation accuracy score is 0.261. Do you think this score is reasonable?)

commented

@ardauzunoglu I’m not sure about the performance. Have you tried fine-tuning bart-base directly (without pre-training)? How does that perform?

commented

@ardauzunoglu Thanks for your attention. The evaluation metrics for Propara differ from those for Alchemy, Scene, and Tangrams: they are divided into sentence-level and document-level metrics, and both should be computed with the evaluation scripts below.
The sentence-level evaluation can be found in:
https://github.com/allenai/propara/blob/master/propara/evaluation/evalQA.py
And the document-level evaluation can be found in:
https://github.com/microsoft/ContextualSP/tree/master/lemon/propara_evaluator

Before running these scripts, please fill in this file based on our model's prediction results:
https://github.com/microsoft/ContextualSP/blob/master/lemon/propara_evaluator/predictions.tsv
For each participant (e.g., "bones"), you should choose the required action based on its initial state and goal state.
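Choosing an action from a participant's location before and after a step follows a simple rule, where `-` marks a participant that does not exist at that point. The sketch below mirrors the `get_action` logic in the conversion script shared later in this thread; `infer_action` is just an illustrative name:

```python
def infer_action(before, after):
    """Map a participant's location before/after a step to a ProPara action.

    '-' means the participant does not exist at that point.
    """
    if before == after:
        return "NONE"
    if before == "-":
        return "CREATE"   # absent before, present after
    if after == "-":
        return "DESTROY"  # present before, absent after
    return "MOVE"         # present in both, but at different locations


print(infer_action("-", "sediment"))  # CREATE
```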

Sorry for any unclear information. If there are still questions, please feel free to reach out.

I could finally pretrain and finetune BART-large. I also evaluated both pretrained & finetuned BART-large and BART-base models with document-level evaluation. Here are the results:

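BART-base:

=================================================
Question     Avg. Precision  Avg. Recall  Avg. F1
-------------------------------------------------
Inputs                0.762        0.494    0.599
Outputs               0.855        0.496    0.628
Conversions           0.304        0.348    0.325
Moves                 0.209        0.431    0.281
-------------------------------------------------
Overall Precision 0.532
Overall Recall    0.442
Overall F1        0.483
=================================================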

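BART-large:

=================================================
Question     Avg. Precision  Avg. Recall  Avg. F1
-------------------------------------------------
Inputs                0.731        0.540    0.621
Outputs               0.852        0.524    0.649
Conversions           0.291        0.348    0.317
Moves                 0.249        0.384    0.302
-------------------------------------------------
Overall Precision 0.531
Overall Recall    0.449
Overall F1        0.486
=================================================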

Here is the code I used to obtain predictions.tsv:

import argparse

import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description="Convert LEMON generation output into a ProPara predictions file")
    parser.add_argument("--id_file", type=str, default=None)    # one "<case>-<step>" id per line
    parser.add_argument("--eval_file", type=str, default=None)  # generate-*.txt.eval output
    return parser.parse_args()


def parse_states(text):
    # "... state : a | b SEP ..." -> ["a", "b"]
    return [s.strip() for s in text.split("state :")[1].split("SEP")[0].split("|")]


def main():
    args = parse_args()
    with open(args.id_file, "r", encoding="utf-8") as f:
        id_lines = f.readlines()
    with open(args.eval_file, "r", encoding="utf-8") as f:
        # columns: score, prediction, gold, source, line id; skip header and trailing empty line
        eval_rows = [line.split("\t") for line in f.read().split("\n")[1:-1]]

    eval_predictions = [row[1] for row in eval_rows]
    eval_sources = [row[3] for row in eval_rows]
    eval_ids = [int(row[4]) for row in eval_rows]

    rows = []
    for i in range(len(id_lines) - 1):
        try:
            case_id, step = id_lines[i].split("-")[0], int(id_lines[i].split("-")[1])
            source = eval_sources[eval_ids.index(i)]
            entities = [e.strip() for e in source.split("state :")[0].replace("col :", "").split("|")]
            if step == 1:
                # state before step 1 comes from the source, state after it from prediction i
                entity_states = parse_states(source)
                next_states = parse_states(eval_predictions[eval_ids.index(i)])
            else:
                # NOTE: here prediction i is used as the state *before* the step
                # and prediction i+1 as the state after it
                entity_states = parse_states(eval_predictions[eval_ids.index(i)])
                next_states = parse_states(eval_predictions[eval_ids.index(i + 1)])

            for idx, entity in enumerate(entities):
                before, after = entity_states[idx], next_states[idx]
                if after == "-" and before != "-":
                    affect = "DESTROY"
                elif before == "-" and after != "-":
                    affect = "CREATE"
                elif before == after:
                    affect = "NONE"
                else:
                    affect = "MOVE"
                rows.append({"ID": case_id, "STEP": step, "ENTITY": entity, "AFFECT": affect,
                             "STATE_T": before, "STATE_T+1": after})
        except (ValueError, IndexError):
            print("skipped line", i)

    # collect rows in a list first (DataFrame.append was removed in pandas 2.0)
    columns = ["ID", "STEP", "ENTITY", "AFFECT", "STATE_T", "STATE_T+1"]
    pd.DataFrame(rows, columns=columns).to_csv("predictions.tsv", index=False, sep="\t")


if __name__ == "__main__":
    main()

For both the pretraining and finetuning, I used the same hyperparameters shared in the pretrain.sh and finetune.sh files in this repository.

Thanks for your continuous help.

commented

@ardauzunoglu The conversion script can be found here:

import argparse


def get_col_states(input_str, cols):
    # "state : a | b" -> {cols[0]: "a", cols[1]: "b"}
    states = input_str.replace('state : ', '').strip().split(' | ')
    return {cols[i]: states[i] for i in range(len(cols))}


def get_col_states_start(input_str):
    # "col : x | y state : a | b SEP ..." -> ({"x": "a", "y": "b"}, ["x", "y"])
    col_and_state = input_str.split(' state : ')
    cols = col_and_state[0].replace('col : ', '').split(' | ')
    states = col_and_state[1].split(' | ')
    states[-1] = states[-1].split(' SEP ')[0]
    return {cols[i]: states[i] for i in range(len(cols))}, cols


def get_action(location_before, location_after):
    # '-' means the participant does not exist at that point
    if location_before == location_after:
        return "NONE", location_before, location_after
    if location_before == '-':
        return "CREATE", location_before, location_after
    if location_after == '-':
        return "DESTROY", location_before, location_after
    return "MOVE", location_before, location_after


def write_case(out, case_id, action_matrix, step_num):
    # one row per (participant, step): case, step, participant, action, before, after
    for key in action_matrix:
        for step_idx in range(step_num):
            try:
                action, before, after = action_matrix[key][step_idx]
            except IndexError:  # no action recorded for this step
                action, before, after = 'NONE', '-', '-'
            out.write('\t'.join([str(case_id), str(step_idx + 1), key, action, before, after]) + '\t\n')


def process(id_path, generate_valid_path, dummy_path, if_answer=False):
    # .eval columns: score, prediction, gold, source, line id
    target_idx = 2 if if_answer else 1
    error_num = 0

    linenum_to_colandstate = {}
    with open(generate_valid_path, 'r', encoding='utf8') as pre:
        for line in pre.readlines()[1:]:  # skip header
            elements = line.rstrip('\n').split('\t')
            linenum_to_colandstate[int(elements[-1])] = elements

    with open(id_path, 'r', encoding='utf8') as id_file:
        id_lines = id_file.read().splitlines()

    current_case = None  # id of the case being processed
    pre_states, cols = {}, []
    step_num = 0  # number of steps seen in the current case
    action_matrix = {}  # participant -> [(action, before, after), ...], one entry per step

    with open(dummy_path, 'w', encoding='utf8') as out:
        for line_id, id_line in enumerate(id_lines):
            case_id, _ = id_line.split('-')  # '4-1' -> ['4', '1']

            if case_id != current_case:
                # new case: write out the previous one, then read the initial state from the source
                if current_case is not None:
                    write_case(out, current_case, action_matrix, step_num)
                current_case = case_id
                step_num = 0
                pre_states, cols = get_col_states_start(linenum_to_colandstate[line_id][-2])
                action_matrix = {key: [] for key in pre_states}

            step_num += 1
            # state after this step's action (prediction or gold, depending on target_idx)
            current_states = get_col_states(linenum_to_colandstate[line_id][target_idx], cols)
            if current_states.keys() != pre_states.keys():
                error_num += 1

            for col in current_states:
                try:
                    action_matrix[col].append(get_action(pre_states[col], current_states[col]))
                except KeyError:
                    # column names disagree: align by position instead
                    error_idx = list(current_states).index(col)
                    pre_col = list(pre_states)[error_idx]
                    right_col = list(action_matrix)[error_idx]
                    action_matrix[right_col].append(get_action(pre_states[pre_col], current_states[col]))
            pre_states = current_states

        # write out the final case, which no later case id triggers inside the loop
        if current_case is not None:
            write_case(out, current_case, action_matrix, step_num)

    print('error_num', error_num)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--id_file", type=str, default='./lemon_data/dataset/propara/test.id',
                        help="one '<case>-<step>' id per line")
    parser.add_argument("--generate_file", type=str, default='./resources/lemon_propara_large/generate-test.txt.eval',
                        help="model generation output (.eval)")
    parser.add_argument("--output_predictions_file", type=str, default='./dummy-predictions.tsv',
                        help="converted predictions")
    parser.add_argument("--output_answers_file", type=str, default='./answers.tsv',
                        help="converted gold answers")
    args = parser.parse_args()

    process(args.id_file, args.generate_file, args.output_predictions_file)
    process(args.id_file, args.generate_file, args.output_answers_file, if_answer=True)

After running this script, you will get the prediction file and the answer file. Then run the following command to evaluate them:

python propara_evaluator/aristo-leaderboard/propara/evaluator/evaluator.py -p dummy-predictions.tsv -a answers.tsv

This should give the expected document-level performance. If you still have questions about reproducing the results, please reach out and I'll address them promptly~
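Conceptually, the conversion walks each case's state sequence (initial state from the source, then one state per step) and emits one row per participant and step, participant by participant. A simplified sketch of that idea on in-memory lists; `states_to_rows` and the toy case are illustrative only, and the sketch skips the file handling and column-mismatch recovery of the script above:

```python
def states_to_rows(case_id, cols, initial_states, step_states):
    """Turn one case's state sequence into evaluator-style rows.

    initial_states: locations before step 1, one per column (participant)
    step_states:    one list of locations per step (the state after that step)
    Returns (case, step, participant, action, before, after) tuples,
    grouped by participant like the converter's output.
    """
    sequence = [list(initial_states)] + [list(s) for s in step_states]
    rows = []
    for i, col in enumerate(cols):
        for step in range(1, len(sequence)):
            before, after = sequence[step - 1][i], sequence[step][i]
            if before == after:
                action = "NONE"
            elif before == "-":
                action = "CREATE"
            elif after == "-":
                action = "DESTROY"
            else:
                action = "MOVE"
            rows.append((case_id, step, col, action, before, after))
    return rows


# Hypothetical two-step case with two participants
rows = states_to_rows("4", ["water", "cloud"], ["soil", "-"], [["air", "-"], ["-", "sky"]])
print(rows[0])  # ('4', 1, 'water', 'MOVE', 'soil', 'air')
```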

After using the converter you shared, I could obtain results similar to those of the paper. Here they are:

BART-base:

=================================================
Question     Avg. Precision  Avg. Recall  Avg. F1
-------------------------------------------------
Inputs                0.903        0.772    0.832
Outputs               0.863        0.811    0.836
Conversions           0.566        0.500    0.531
Moves                 0.404        0.560    0.469
-------------------------------------------------
Overall Precision 0.684                          
Overall Recall    0.661                          
Overall F1        0.672                          
=================================================

BART-large:

=================================================
Question     Avg. Precision  Avg. Recall  Avg. F1
-------------------------------------------------
Inputs                0.869        0.873    0.871
Outputs               0.953        0.892    0.921
Conversions           0.554        0.535    0.544
Moves                 0.533        0.556    0.544
-------------------------------------------------
Overall Precision 0.727                          
Overall Recall    0.714                          
Overall F1        0.721                          
=================================================
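The overall rows are consistent with macro-averaging the four per-question scores and taking the harmonic mean of the result; for BART-base, (0.903 + 0.863 + 0.566 + 0.404) / 4 = 0.684. A sketch that recomputes them from the per-question numbers (a check on the reported figures, not the evaluator's own code):

```python
def overall_scores(per_question):
    """per_question: {question: (precision, recall)} -> (P, R, F1).

    Overall precision and recall are unweighted means over the questions;
    F1 is the harmonic mean of those two averages.
    """
    precisions = [p for p, _ in per_question.values()]
    recalls = [r for _, r in per_question.values()]
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1


bart_large = {
    "Inputs": (0.869, 0.873),
    "Outputs": (0.953, 0.892),
    "Conversions": (0.554, 0.535),
    "Moves": (0.533, 0.556),
}
print(overall_scores(bart_large))
```

Because the inputs are already rounded to three decimals, the recomputed values can differ from the table in the last digit.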

Thank you so much for your help. I also tried to evaluate these models with the sentence-level evaluation, but I could not figure out how to convert the model predictions to the format used in the sample predictions. I also got the following error when I ran the example usage with the gold labels and model predictions shared here:

Traceback (most recent call last):
  File "sent_evaluation/evalQA.py", line 488, in <module>
    main()
  File "sent_evaluation/evalQA.py", line 434, in main
    tp, fp, tn, fn = Q(labels, predictions)
  File "sent_evaluation/evalQA.py", line 241, in Q1
    pred_creation_step = findCreationStep(predictions[pid][participant].values())
KeyError: 'fossil'

Do you have any insights into the sentence-level evaluation?

Thanks again, I really appreciate it.

commented

@ardauzunoglu We just use the official evalQA.py file to perform the sentence-level evaluation, without any modification. If you see a "KeyError", it may be because some participant names in the dataset differ slightly from those in the ground truth, so you may need to map them manually. This shouldn't be much of a hassle, as only a handful of participants should be misaligned. Hope it helps~
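For the handful of mismatched names (e.g. `fossils` in the predictions vs. `fossil` in the gold labels), a fuzzy match can propose a mapping that you then confirm by hand. A sketch using Python's standard `difflib.get_close_matches`; `align_participants` and the 0.8 cutoff are assumptions, not part of evalQA.py:

```python
from difflib import get_close_matches


def align_participants(predicted, gold, cutoff=0.8):
    """Map each predicted participant name to a gold participant name.

    Exact matches are kept; otherwise the closest gold name whose
    similarity ratio is >= cutoff is proposed. Names without a candidate
    map to None and need manual review.
    """
    gold = list(gold)
    mapping = {}
    for name in predicted:
        if name in gold:
            mapping[name] = name
            continue
        candidates = get_close_matches(name, gold, n=1, cutoff=cutoff)
        mapping[name] = candidates[0] if candidates else None
    return mapping


print(align_participants(["fossils", "rock", "mud"], ["fossil", "rock", "water"]))
# {'fossils': 'fossil', 'rock': 'rock', 'mud': None}
```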

I see, but I was wondering if you had a converter to convert the model predictions to the format used in the sample predictions. For example, the following row is taken from the model predictions:

37 1 fossils NONE - -

and the following is the counterpart of the above-mentioned row from the sample predictions:

37 1 fossil null null nowhere nowhere

Similarly, the rows from the gold_labels.test.tsv have some columns that are absent in answers.tsv:

For example, the following row is taken from answers.tsv:

38 5 rock CREATE - around the body

and the following is the counterpart of the above-mentioned row from gold_labels.test.tsv:

38	5	rock	before	null	1.0
38	5	rock	after	sandy or wet place	0.2	around the body	0.6	unk	0.2

Thank you.
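The gold rows above pair each candidate location with a probability (for `after`: `sandy or wet place` 0.2, `around the body` 0.6, `unk` 0.2). A small parser, written only from the two example rows shown here (the exact file layout beyond them is an assumption, and `parse_gold_row` is a hypothetical name), makes that structure explicit:

```python
def parse_gold_row(line):
    """Parse a gold_labels row: pid, sid, participant, 'before'/'after',
    then alternating location / probability fields."""
    fields = line.rstrip("\n").split("\t")
    pid, sid, participant, which = int(fields[0]), int(fields[1]), fields[2], fields[3]
    rest = fields[4:]
    if len(rest) % 2:
        raise ValueError("expected location/probability pairs: %r" % line)
    candidates = [(rest[i], float(rest[i + 1])) for i in range(0, len(rest), 2)]
    return pid, sid, participant, which, candidates


row = "38\t5\trock\tafter\tsandy or wet place\t0.2\taround the body\t0.6\tunk\t0.2"
pid, sid, participant, which, candidates = parse_gold_row(row)
print(max(candidates, key=lambda c: c[1]))  # ('around the body', 0.6)
```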

commented

@ardauzunoglu I found this script and hope it will be helpful.

import sys


def switch_unk_null(location):
    # LEMON uses '-' for "does not exist" and '?' for "unknown"; ProPara uses 'null' and 'unk'
    return {'-': 'null', '?': 'unk'}.get(location, location)


def deal(src, output):
    # dummy-predictions.tsv row: case, step, participant, action, before, after
    # written row:               case, step, participant, before, after
    for line in src:
        elements = line.rstrip('\n').split('\t')
        case_id, step_id, participant = elements[0], elements[1], elements[2]
        try:
            before = switch_unk_null(elements[4])
            after = switch_unk_null(elements[5])
        except IndexError:
            print(elements)
            sys.exit(1)
        output.write('\t'.join([case_id, step_id, participant, before, after]) + '\n')


def deal_id(src, output):
    # '<case>-<step>' lines -> one line per distinct case id, in order
    current_case = None
    for line in src:
        if not line.strip():
            continue
        case_id = line.split('-')[0]
        if case_id != current_case:
            output.write(case_id + '\n')
            current_case = case_id


def process_cal_cat(src_dir):
    with open(src_dir + '/dummy-predictions.tsv', 'r', encoding='utf8') as src, \
            open(src_dir + '/sample.model.test_predictions.tsv', 'w', encoding='utf8') as src_out:
        deal(src, src_out)

    # answers.tsv can be converted the same way into gold_labels.test.tsv (disabled here):
    # with open(src_dir + '/answers.tsv', 'r', encoding='utf8') as out, \
    #         open(src_dir + '/gold_labels.test.tsv', 'w', encoding='utf8') as golden_out:
    #     deal(out, golden_out)

    with open(src_dir + '/dev.id', 'r', encoding='utf8') as id_input, \
            open(src_dir + '/para_id.test.txt', 'w', encoding='utf8') as id_out:
        deal_id(id_input, id_out)


if __name__ == "__main__":
    process_cal_cat('PATH-TO-FILE')
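Applied to the example row from earlier, the conversion keeps the first three columns, drops the action, and maps `-`/`?` to `null`/`unk`. The function below is an illustrative single-row version of `deal` (the name `convert_row` is hypothetical):

```python
def convert_row(row):
    """dummy-predictions row (case, step, participant, action, before, after)
    -> sample-prediction row (case, step, participant, before, after)."""
    placeholders = {"-": "null", "?": "unk"}
    case_id, step_id, participant, _action, before, after = row.rstrip("\n").split("\t")[:6]
    return "\t".join([case_id, step_id, participant,
                      placeholders.get(before, before), placeholders.get(after, after)])


print(convert_row("37\t1\tfossils\tNONE\t-\t-\t").split("\t"))
# ['37', '1', 'fossils', 'null', 'null']
```

Note that the participant name passes through unchanged (`fossils` stays `fossils`), so names that differ from the gold labels still have to be mapped separately.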

After converting the predictions and answers with the script you shared and slightly modifying evalQA.py, I obtained the following evaluation results:

BART-large:

	Total	TP	FP	TN	FN	Accuracy	Precision	Recall	F1
Q1	990	515	10	429	36	95.35	98.10	93.47	95.72
Q2	565	434	90	0	41	76.81	82.82	91.37	86.89
Q3	448	177	148	0	123	39.51	54.46	59.00	56.64
Q4	990	389	111	382	108	77.88	77.80	78.27	78.03
Q5	509	231	170	0	108	45.38	57.61	68.14	62.43
Q6	399	154	102	0	143	38.60	60.16	51.85	55.70
Q7	990	207	143	518	122	73.23	59.14	62.92	60.97
Q8	2353	236	149	1691	277	52.56	61.30	46.00	52.56
Q9	400	141	115	0	144	35.25	55.08	49.47	52.13
Q10	472	165	130	0	177	34.96	55.93	48.25	51.81


Category	Accuracy Score
=========	=====
Cat-1		82.15
Cat-2		58.25
Cat-3		37.08
macro-avg	59.16
micro-avg	58.48

BART-base:

	Total	TP	FP	TN	FN	Accuracy	Precision	Recall	F1
Q1	990	512	8	431	39	95.25	98.46	92.92	95.61
Q2	565	436	90	0	39	77.17	82.89	91.79	87.11
Q3	448	187	109	0	152	41.74	63.18	55.16	58.90
Q4	990	334	116	377	163	71.82	74.22	67.20	70.54
Q5	509	160	180	0	169	31.43	47.06	48.63	47.83
Q6	399	169	78	0	152	42.36	68.42	52.65	59.51
Q7	990	206	199	462	123	67.47	50.86	62.61	56.13
Q8	2353	242	194	1646	271	51.00	55.50	47.17	51.00
Q9	400	165	93	0	142	41.25	63.95	53.75	58.41
Q10	472	184	110	0	178	38.98	62.59	50.83	56.10


Category	Accuracy Score
=========	=====
Cat-1		78.18
Cat-2		53.2
Cat-3		41.08
macro-avg	57.49
micro-avg	57.23
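The category rows above are consistent with plain averages of the per-question accuracies: Cat-1 over Q1/Q4/Q7, Cat-2 over Q2/Q5/Q8, Cat-3 over Q3/Q6/Q9/Q10, and macro-avg over the three categories. A sketch of that aggregation, inferred from the numbers rather than taken from evalQA.py (the micro-avg row is not reproduced by it):

```python
CATEGORIES = {
    "Cat-1": ["Q1", "Q4", "Q7"],
    "Cat-2": ["Q2", "Q5", "Q8"],
    "Cat-3": ["Q3", "Q6", "Q9", "Q10"],
}


def category_scores(accuracy):
    """accuracy: {"Q1": 95.35, ...} -> per-category means plus macro-avg."""
    scores = {cat: sum(accuracy[q] for q in qs) / len(qs) for cat, qs in CATEGORIES.items()}
    scores["macro-avg"] = sum(scores.values()) / len(CATEGORIES)
    return scores


bart_large_acc = {"Q1": 95.35, "Q2": 76.81, "Q3": 39.51, "Q4": 77.88, "Q5": 45.38,
                  "Q6": 38.60, "Q7": 73.23, "Q8": 52.56, "Q9": 35.25, "Q10": 34.96}
print({k: round(v, 2) for k, v in category_scores(bart_large_acc).items()})
# {'Cat-1': 82.15, 'Cat-2': 58.25, 'Cat-3': 37.08, 'macro-avg': 59.16}
```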

Thanks for your continuous help throughout my attempts.

By the way, here is my modified evalQA.py, in case it is useful in the future:

import sys, collections, pylev
from stemming.porter2 import stem

#--------------------------------------------------------------
# Author: Scott Wen-tau Yih
# Usage: evalQA.py para-ids gold-labels system-predictions
# example usage: python propara/eval/evalQA.py tests/fixtures/eval/para_id.test.txt tests/fixtures/eval/gold_labels.test.tsv tests/fixtures/eval/sample.model.test_predictions.tsv 
#--------------------------------------------------------------

# Data structure for Labels
'''
  PID -> [TurkerLabels]
  TurkerLabels = [TurkerQuestionLabel1, TurkerQuestionLabel2, ... ]  # labels on the same paragraph from the same Turker
  TurkerQuestionLabel -> (SID, Participant, Type, From, To)
'''
TurkerQuestionLabel = collections.namedtuple('TurkerQuestionLabel', 'sid participant event_type from_location to_location')


# Data structure for Predictions
'''
  PID -> Participant -> SID -> PredictionRecord
'''
PredictionRecord = collections.namedtuple('PredictionRecord', 'pid sid participant from_location to_location')

# Fixing tokenization mismatch while aligning participants
manual_participant_map = { 'alternating current':'alternate current', 'fixed nitrogen':'nitrogen',
                           'living things':'live thing', 'red giant star':'star', 'refrigerent liquid':'liquid',
                           'remains of living things':'remains of live thing',
                           "retina's rods and cones":"retina 's rod and cone" } #, 'seedling':'seed'}

#----------------------------------------------------------------------------------------------------------------

def compare_to_gold_labels(participants, system_participants):
    ret = []
    found = False
    for p in participants:
        p = p.lower()
        if p in system_participants:
            ret.append(p)
            continue
        for g in system_participants:
            if (pylev.levenshtein(p,g) < 3):
                #print (p, "===", g)
                ret.append(g)
                found = True
        if not found:
            if p in manual_participant_map:
                ret.append(manual_participant_map[p])
            #else:
            #    print("cannot find", p, system_participants)
    return ret

def preprocess_locations(locations):
    ret = []
    for loc in locations:
        if loc == '-':
            ret.append('null')
        elif loc == '?':
            ret.append('unk')
        else:
            ret.append(loc)
    return ret


def preprocess_question_label(sid, participant, event_type, from_location, to_location, system_participants=None):

    # check if there are multiple participants grouped together
    participants = [x.strip() for x in participant.split(';')]

    # check if there are multiple locations grouped together
    from_locations = preprocess_locations([x.strip() for x in from_location.split(';')])

    # check if there are multiple locations grouped together
    to_locations = preprocess_locations([x.strip() for x in to_location.split(';')])

    #print(participant, participants, system_participants)
    if system_participants != None: # check if the participants are in his list
        participants = compare_to_gold_labels(participants, system_participants)
        #print("legit_participants =", participants)

    #print(from_location, from_locations)
    #print(to_location, to_locations)

    return  [TurkerQuestionLabel(sid, p, event_type, floc, tloc) for p in participants
                                                                 for floc in from_locations
                                                                 for tloc in to_locations]

#----------------------------------------------------------------------------------------------------------------

'''
  Read the gold file containing all records where an entity undergoes some state-change: create/destroy/move.
'''
def readLabels(fnLab, selPid=None, gold_labels=None):
    fLab = open(fnLab)
    fLab.readline()    # skip header
    ret = {}
    TurkerLabels = []
    for ln in fLab:
        f = ln.rstrip().split('\t')
        if len(f) == 0 or len(f) == 1:
            if not selPid or pid in selPid:
                if pid not in ret:
                    ret[pid] = []
                ret[pid].append(TurkerLabels)
            TurkerLabels = []
        elif len(f) != 11:
            sys.stderr.write("Error: the number of fields in this line is irregular: " + ln)
            sys.exit(-1)
        else:
            if f[1] == '?': continue
            pid, sid, participant, event_type, from_location, to_location = int(f[0]), int(f[1]), f[3], f[4], f[5], f[6]

            if gold_labels and selPid and pid in selPid:
                #print("pid=", pid)
                try:
                    TurkerLabels += preprocess_question_label(sid, participant, event_type, from_location, to_location, gold_labels[pid].keys())
                except KeyError:
                    pass
            else:
                TurkerLabels += preprocess_question_label(sid, participant, event_type, from_location, to_location)

            #TurkerLabels += (TurkerQuestionLabel(sid, participant, event_type, from_location, to_location))
    return ret

#----------------------------------------------------------------------------------------------------------------

def readPredictions(fnPred):
    ret = {}

    for ln in open(fnPred):
        f = ln.rstrip().split('\t')
        pid, sid, participant, from_location, to_location = int(f[0]), int(f[1]), f[2], f[3], f[4]

        if pid not in ret:
            ret[pid] = {}
        dtPartPred = ret[pid]

        if participant not in dtPartPred:
            dtPartPred[participant] = {}

        dtPartPred[participant][sid] = PredictionRecord(pid, sid, participant, from_location, to_location)

    return ret

#----------------------------------------------------------------------------------------------------------------

def readGold(fn):
    # read the gold label

    dtPar = {}
    for ln in open(fn):
        f = ln.rstrip().split('\t')
        parId, sentId, participant, labels = int(f[0]), int(f[1]), f[2], f[3:]

        if sentId == 1:
            statusId = 0
        else:
            statusId = sentId

        if parId not in dtPar:
            dtPar[parId] = {}
        dtPartLab = dtPar[parId]
        if participant not in dtPartLab:
            dtPartLab[participant] = {statusId: labels}
        else:
            dtPartLab[participant][statusId] = labels
    return dtPar

#----------------------------------------------------------------------------------------------------------------

def findAllParticipants(lstTurkerLabels):
    setParticipants = set()
    for turkerLabels in lstTurkerLabels:
        for x in turkerLabels:
            setParticipants.add(x.participant)
    return setParticipants

def findCreationStep(prediction_records):
    steps = sorted(prediction_records, key=lambda x: x.sid)
    #print("steps:", steps)

    # first step
    if steps[0].from_location != 'null':    # not created (exists before the process)
        return -1

    for s in steps:
        if s.to_location != 'null':
            return s.sid
    return -1   # never exists

def findDestroyStep(prediction_records):
    steps = sorted(prediction_records, key=lambda x: x.sid, reverse=True)
    #print("steps:", steps)

    # last step
    if steps[0].to_location != 'null':  # not destroyed (exists after the process)
        return -1

    for s in steps:
        if s.from_location != 'null':
            return s.sid

    return -1   # never exists

def location_match(p_loc, g_loc):
    if p_loc == g_loc:
        return True

    p_string = ' %s ' % ' '.join([stem(x) for x in p_loc.lower().replace('"','').split()])
    g_string = ' %s ' % ' '.join([stem(x) for x in g_loc.lower().replace('"','').split()])

    if p_string in g_string:
        #print ("%s === %s" % (p_loc, g_loc))
        return True

    return False

def findMoveSteps(prediction_records):
    ret = []
    steps = sorted(prediction_records, key=lambda x: x.sid)
    # print(steps)
    for s in steps:
        if s.from_location != 'null' and s.to_location != 'null' and s.from_location != s.to_location:
            ret.append(s.sid)

    return ret

#----------------------------------------------------------------------------------------------------------------

# Q1: Is participant X created during the process?
def Q1(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        setParticipants = findAllParticipants(labels[pid])
        # find predictions
        be_created = {}
        for participant in setParticipants:
            pred_creation_step = findCreationStep(predictions[pid][participant].values())
            be_created[participant] = (pred_creation_step != -1)
            
        for turkerLabels in labels[pid]:
            # labeled as created participants
            lab_created_participants = [x.participant for x in turkerLabels if x.event_type == 'create']
            for participant in setParticipants:
                tp += int(be_created[participant] and (participant in lab_created_participants))
                fp += int(be_created[participant] and (participant not in lab_created_participants))
                tn += int(not be_created[participant] and (participant not in lab_created_participants))
                fn += int(not be_created[participant] and (participant in lab_created_participants))
    return tp,fp,tn,fn

# Q2: Participant X is created during the process. At which step is it created?
def Q2(labels, predictions):
    tp = fp = tn = fn = 0.0
    # find all created participants and their creation step
    for pid,lstTurkerLabels in labels.items():
        for turkerLabels in lstTurkerLabels:
            for x in [x for x in turkerLabels if x.event_type == 'create']:
                pred_creation_step = findCreationStep(predictions[pid][x.participant].values())
                tp += int(pred_creation_step != -1 and pred_creation_step == x.sid)
                fp += int(pred_creation_step != -1 and pred_creation_step != x.sid)
                fn += int(pred_creation_step == -1)
    return tp,fp,tn,fn

# Q3: Participant X is created at step Y, and the initial location is known. Where is the participant after it is created?
def Q3(labels, predictions):
    tp = fp = tn = fn = 0.0
    # find all created participants and their creation step
    for pid,lstTurkerLabels in labels.items():
        for turkerLabels in lstTurkerLabels:
            for x in [x for x in turkerLabels if x.event_type == 'create' and x.to_location != 'unk']:
                pred_loc = predictions[pid][x.participant][x.sid].to_location
                tp += int(pred_loc != 'null' and pred_loc != 'unk' and location_match(pred_loc, x.to_location))
                fp += int(pred_loc != 'null' and pred_loc != 'unk' and not location_match(pred_loc, x.to_location))
                fn += int(pred_loc == 'null' or pred_loc == 'unk')
    return tp, fp, tn, fn

#----------------------------------------------------------------------------------------------------------------

# Q4: Is participant X destroyed during the process?
def Q4(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        setParticipants = findAllParticipants(labels[pid])
        # find predictions
        be_destroyed = {}
        for participant in setParticipants:
            pred_destroy_step = findDestroyStep(predictions[pid][participant].values())
            be_destroyed[participant] = (pred_destroy_step != -1)
        for turkerLabels in labels[pid]:
            # labeled as destroyed participants
            lab_destroyed_participants = [x.participant for x in turkerLabels if x.event_type == 'destroy']
            for participant in setParticipants:
                tp += int(be_destroyed[participant] and (participant in lab_destroyed_participants))
                fp += int(be_destroyed[participant] and (participant not in lab_destroyed_participants))
                tn += int(not be_destroyed[participant] and (participant not in lab_destroyed_participants))
                fn += int(not be_destroyed[participant] and (participant in lab_destroyed_participants))
    return tp,fp,tn,fn
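
# Worked example for Q4 (illustrative values, not from the dataset):
# a paragraph has participants {'water', 'sugar'} and the model predicts that
# only 'water' is destroyed. An annotator who labels both as destroyed
# contributes tp=1 ('water') and fn=1 ('sugar'); a second annotator who labels
# only 'water' contributes tp=1 ('water') and tn=1 ('sugar'). Counts are summed
# over every annotator's label set, so each participant is scored once per
# annotator. Q7 scores 'move' events the same way.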

# Q5: Participant X is destroyed during the process. At which step is it destroyed?
def Q5(labels, predictions):
    tp = fp = tn = fn = 0.0
    # find all destroyed participants and their destroy step
    for pid, lstTurkerLabels in labels.items():
        for turkerLabels in lstTurkerLabels:
            for x in [x for x in turkerLabels if x.event_type == 'destroy']:
                pred_destroy_step = findDestroyStep(predictions[pid][x.participant].values())
                tp += int(pred_destroy_step != -1 and pred_destroy_step == x.sid)
                fp += int(pred_destroy_step != -1 and pred_destroy_step != x.sid)
                fn += int(pred_destroy_step == -1)
    return tp,fp,tn,fn

# Q6: Participant X is destroyed at step Y, and its location before destroyed is known. Where is the participant right before it is destroyed?
def Q6(labels, predictions):
    tp = fp = tn = fn = 0.0
    # find all destroyed participants and their destroy step
    for pid,lstTurkerLabels in labels.items():
        for turkerLabels in lstTurkerLabels:
            for x in [x for x in turkerLabels if x.event_type == 'destroy' and x.from_location != 'unk']:
                pred_loc = predictions[pid][x.participant][x.sid].from_location
                tp += int(pred_loc != 'null' and pred_loc != 'unk' and location_match(pred_loc, x.from_location))
                fp += int(pred_loc != 'null' and pred_loc != 'unk' and not location_match(pred_loc, x.from_location))
                fn += int(pred_loc == 'null' or pred_loc == 'unk')
    return tp, fp, tn, fn

#----------------------------------------------------------------------------------------------------------------

# Q7: Does participant X move during the process?
def Q7(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        setParticipants = findAllParticipants(labels[pid])
        # find predictions
        be_moved = {}
        for participant in setParticipants:
            pred_move_steps = findMoveSteps(predictions[pid][participant].values())
            be_moved[participant] = (pred_move_steps != [])

        for turkerLabels in labels[pid]:
            lab_moved_participants = [x.participant for x in turkerLabels if x.event_type == 'move']
            for participant in setParticipants:
                tp += int(be_moved[participant] and (participant in lab_moved_participants))
                fp += int(be_moved[participant] and (participant not in lab_moved_participants))
                tn += int(not be_moved[participant] and (participant not in lab_moved_participants))
                fn += int(not be_moved[participant] and (participant in lab_moved_participants))

    return tp,fp,tn,fn

# Q8: Participant X moves during the process. At which steps does it move?
def Q8(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        setParticipants = findAllParticipants(labels[pid])

        # find predictions
        pred_moved_steps = {}
        for participant in setParticipants:
            pred_moved_steps[participant] = findMoveSteps(predictions[pid][participant].values())

        for turkerLabels in labels[pid]:
            gold_moved_steps = {}
            for x in [x for x in turkerLabels if x.event_type == 'move']:
                if x.participant not in gold_moved_steps:
                    gold_moved_steps[x.participant] = []
                gold_moved_steps[x.participant].append(x.sid)

            for participant in gold_moved_steps:
                res = set_compare(pred_moved_steps[participant], gold_moved_steps[participant], len(predictions[pid][participant].keys()))
                tp += res[0]
                fp += res[1]
                tn += res[2]
                fn += res[3]
    return tp,fp,tn,fn

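def set_compare(pred_steps, gold_steps, num_steps):
    """Return (tp, fp, tn, fn) counts over the num_steps steps of a paragraph.

    A step is a tp if it is in both the predicted and gold move steps, an fp if
    it is only predicted, an fn if it is only gold, and a tn if it is in neither.
    """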
    setPred = set(pred_steps)
    setGold = set(gold_steps)
    tp = len(setPred.intersection(setGold))
    fp = len(setPred - setGold)
    fn = len(setGold - setPred)
    tn = num_steps - tp - fp - fn
    return (tp, fp, tn, fn)
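
# Worked example (illustrative values): set_compare([1, 3], [3, 4], num_steps=5)
# has intersection {3} -> tp = 1, predicted-only {1} -> fp = 1,
# gold-only {4} -> fn = 1, and tn = 5 - 1 - 1 - 1 = 2, so it returns (1, 1, 2, 1).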

# Q9: Participant X moves at step Y, and its location before step Y is known. What is its location before step Y?
def Q9(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        for turkerLabels in labels[pid]:
            for x in turkerLabels:
                if x.event_type == 'move' and x.from_location != 'unk':
                    pred_loc = predictions[pid][x.participant][x.sid].from_location
                    tp += int(pred_loc != 'null' and pred_loc != 'unk' and location_match(pred_loc, x.from_location))
                    fp += int(pred_loc != 'null' and pred_loc != 'unk' and not location_match(pred_loc, x.from_location))
                    fn += int(pred_loc == 'null' or pred_loc == 'unk')
    return tp,fp,tn,fn

# Q10: Participant X moves at step Y, and its location after step Y is known. What is its location after step Y?
def Q10(labels, predictions):
    tp = fp = tn = fn = 0.0
    for pid in labels:
        for turkerLabels in labels[pid]:
            for x in turkerLabels:
                if x.event_type == 'move' and x.to_location != 'unk':
                    pred_loc = predictions[pid][x.participant][x.sid].to_location
                    tp += int(pred_loc != 'null' and pred_loc != 'unk' and location_match(pred_loc, x.to_location))
                    fp += int(pred_loc != 'null' and pred_loc != 'unk' and not location_match(pred_loc, x.to_location))
                    fn += int(pred_loc == 'null' or pred_loc == 'unk')
    return tp,fp,tn,fn

#----------------------------------------------------------------------------------------------------------------

def main():
    if len(sys.argv) != 4:
        sys.stderr.write("Usage: evalQA.py para-ids gold-labels system-predictions\n")
        sys.exit(-1)
    paraIds = sys.argv[1]
    goldPred = sys.argv[2]
    fnPred = sys.argv[3]
    qid_to_score = {}

    with open(paraIds) as f:
        selPid = set(int(x) for x in f)
    gold_labels = readGold(goldPred)
    labels = readLabels('sent_evaluation/all-moves.full-grid.tsv', selPid, gold_labels)
    predictions = readPredictions(fnPred)

    blHeader = True
    for qid, Q in enumerate([Q1, Q2, Q3, Q4, Q5, Q6, Q7, Q8, Q9, Q10], start=1):
        tp, fp, tn, fn = Q(labels, predictions)
        header,results_str, results = metrics(tp,fp,tn,fn,qid)
        if blHeader:
            print("\t%s" % header)
            blHeader = False
        print("Q%d\t%s" % (qid, results_str))
        qid_to_score[qid] = results[5]  # results[5] is accuracy (F1 for Q8), in percent

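    # Cat-1: yes/no questions (Q1, Q4, Q7); Cat-2: at which step (Q2, Q5, Q8);
    # Cat-3: location questions (Q3, Q6, Q9, Q10).
    cat1_score = (qid_to_score[1] + qid_to_score[4] + qid_to_score[7]) / 3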
    cat2_score = (qid_to_score[2] + qid_to_score[5] + qid_to_score[8]) / 3
    cat3_score = (qid_to_score[3] + qid_to_score[6] + qid_to_score[9] + qid_to_score[10]) / 4

    macro_avg = (cat1_score + cat2_score + cat3_score) / 3
    # hardcoded number of questions per category, used as weights in the micro average
    num_cat1_q = 750
    num_cat2_q = 601
    num_cat3_q = 823
    micro_avg = ((cat1_score * num_cat1_q) + (cat2_score * num_cat2_q) + (cat3_score * num_cat3_q)) / \
                (num_cat1_q + num_cat2_q + num_cat3_q)
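    # Illustrative arithmetic (made-up scores, not actual results): with
    # cat1_score = 60, cat2_score = 50 and cat3_score = 40, macro_avg = 50.0, while
    # micro_avg = (60*750 + 50*601 + 40*823) / 2174 = 107970 / 2174, about 49.66,
    # because the micro average weights each category by its question count.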

    print("\n\nCategory\tAccuracy Score")
    print("=========\t=====")
    print(f"Cat-1\t\t{round(cat1_score,2)}")
    print(f"Cat-2\t\t{round(cat2_score,2)}")
    print(f"Cat-3\t\t{round(cat3_score,2)}")
    print(f"macro-avg\t{round(macro_avg,2)}")
    print(f"micro-avg\t{round(micro_avg,2)}")

def metrics(tp, fp, tn, fn, qid):
    if (tp+fp > 0):
        prec = tp/(tp+fp)
    else:
        prec = 0.0
    if (tp+fn > 0):
        rec = tp/(tp+fn)
    else:
        rec = 0.0
    if (prec + rec) != 0:
        f1 = 2 * prec * rec / (prec + rec)
    else:
        f1 = 0.0
    total = tp + fp + tn + fn
    # guard against an empty question set, which would otherwise divide by zero
    accuracy = (tp + tn) / total if total > 0 else 0.0
    if qid == 8:
        accuracy = f1   # Q8 can have multiple valid answers, so F1 makes more sense here

    header = '\t'.join(["Total", "TP", "FP", "TN", "FN", "Accuracy", "Precision", "Recall", "F1"])
    results = [total, tp, fp, tn, fn, accuracy*100, prec*100, rec*100, f1*100]
    results_str = "%d\t%d\t%d\t%d\t%d\t%.2f\t%.2f\t%.2f\t%.2f" % (total, tp, fp, tn, fn, accuracy*100, prec*100, rec*100, f1*100)
    return (header, results_str, results)
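
# Worked example (illustrative counts): metrics(6, 2, 8, 4, qid=1) gives
# precision = 6/8 = 0.75, recall = 6/10 = 0.6,
# F1 = 2*0.75*0.6 / (0.75+0.6) = 0.667 (rounded) and accuracy = (6+8)/20 = 0.70,
# so results_str reads "20\t6\t2\t8\t4\t70.00\t75.00\t60.00\t66.67".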

#----------------------------------------------------------------------------------------------------------------


if __name__ == "__main__":
    main()