@@ -24,7 +24,7 @@ def main(args, path_to_candidate_bonds):
24
24
num_processes = args ['num_processes' ])
25
25
else :
26
26
train_set = WLNRankDataset (
27
- path_to_reaction_file = args [ 'train_path' ] ,
27
+ path_to_reaction_file = 'train_valid_reactions.proc' ,
28
28
candidate_bond_path = path_to_candidate_bonds ['train' ], mode = 'train' ,
29
29
max_num_change_combos_per_reaction = args ['max_num_change_combos_per_reaction_train' ],
30
30
num_processes = args ['num_processes' ])
@@ -36,7 +36,7 @@ def main(args, path_to_candidate_bonds):
36
36
num_processes = args ['num_processes' ])
37
37
else :
38
38
val_set = WLNRankDataset (
39
- path_to_reaction_file = args [ 'val_path' ] ,
39
+ path_to_reaction_file = 'val_valid_reactions.proc' ,
40
40
candidate_bond_path = path_to_candidate_bonds ['val' ], mode = 'val' ,
41
41
max_num_change_combos_per_reaction = args ['max_num_change_combos_per_reaction_eval' ],
42
42
num_processes = args ['num_processes' ])
@@ -133,6 +133,14 @@ def main(args, path_to_candidate_bonds):
133
133
t0 = time .time ()
134
134
model .train ()
135
135
136
+ # Final results
137
+ torch .save ({'model_state_dict' : model .state_dict ()},
138
+ args ['result_path' ] + '/model_final.pkl' )
139
+ prediction_summary = 'final\n ' + candidate_ranking_eval (args , model , val_loader )
140
+ print (prediction_summary )
141
+ with open (args ['result_path' ] + '/val_eval.txt' , 'a' ) as f :
142
+ f .write (prediction_summary )
143
+
136
144
if __name__ == '__main__' :
137
145
from argparse import ArgumentParser
138
146
0 commit comments