Skip to content

Commit 30aa61c

Browse files
authored
Hot Fix for Reaction Center Prediction (#116)
* Update * Update * Update * Update * Update * Update * Update * Update
1 parent 45ac277 commit 30aa61c

File tree

5 files changed

+33
-12
lines changed

5 files changed

+33
-12
lines changed

examples/reaction_prediction/rexgen_direct/README.md

+10-8
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ You can then train a model on new datasets with
218218
python find_reaction_center_train.py --train-path X --val-path Y
219219
```
220220

221-
where `X`, `Y` are paths to the new training/validation as described above.
221+
where `X` and `Y` are the paths to the new training and validation datasets, respectively, as described above.
222222

223223
For evaluation,
224224

225225
```bash
226-
python find_reaction_center_eval.py --eval-path Z
226+
python find_reaction_center_eval.py --test-path Z
227227
```
228228

229229
where `Z` is the path to the new test set as described above.
@@ -334,20 +334,22 @@ python candidate_ranking_eval.py
334334
You can train a model on new datasets with
335335

336336
```bash
337-
python candidate_ranking_train.py --train-path train_valid_reactions.proc --val-path val_valid_reactions.proc -cmp X
337+
python candidate_ranking_train.py --train-path X --val-path Y -cmp Z
338338
```
339339

340-
where `X` is the path to a trained model for reaction center prediction. You can use our
341-
pre-trained model by not specifying `-cmp`.
340+
where `X` and `Y` are the paths to the new training and validation datasets, as in reaction center prediction, and `Z` is
341+
the path to a trained model for reaction center prediction. You can use our pre-trained model by not specifying `-cmp`.
342342

343343
For evaluation,
344344

345345
```bash
346-
python candidate_ranking_eval.py --model-path X -cmp Y --eval-path test_valid_reactions.proc
346+
python candidate_ranking_eval.py --model-path X -cmp Y --test-path Z
347347
```
348348

349-
where `X` is the path to a trained model for candidate ranking and `Y` is the path to a trained model
350-
for reaction center prediction. As in training, you can use our pre-trained model by not specifying `-cmp`.
349+
where `X` is the path to a trained model for candidate ranking, `Y` is the path to a trained model
350+
for reaction center prediction, and `Z` is the path to the new test dataset as in reaction center prediction.
351+
You can use the pre-trained model for reaction center prediction by not specifying `-cmp`, and the pre-trained
352+
model for candidate ranking by not specifying `--model-path`.
351353

352354
### Common Issues
353355

examples/reaction_prediction/rexgen_direct/candidate_ranking_eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def main(args, path_to_candidate_bonds):
2020
num_processes=args['num_processes'])
2121
else:
2222
test_set = WLNRankDataset(
23-
path_to_reaction_file=args['test_path'],
23+
path_to_reaction_file='test_valid_reactions.proc',
2424
candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
2525
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
2626
num_processes=args['num_processes'])

examples/reaction_prediction/rexgen_direct/candidate_ranking_train.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def main(args, path_to_candidate_bonds):
2424
num_processes=args['num_processes'])
2525
else:
2626
train_set = WLNRankDataset(
27-
path_to_reaction_file=args['train_path'],
27+
path_to_reaction_file='train_valid_reactions.proc',
2828
candidate_bond_path=path_to_candidate_bonds['train'], mode='train',
2929
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_train'],
3030
num_processes=args['num_processes'])
@@ -36,7 +36,7 @@ def main(args, path_to_candidate_bonds):
3636
num_processes=args['num_processes'])
3737
else:
3838
val_set = WLNRankDataset(
39-
path_to_reaction_file=args['val_path'],
39+
path_to_reaction_file='val_valid_reactions.proc',
4040
candidate_bond_path=path_to_candidate_bonds['val'], mode='val',
4141
max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
4242
num_processes=args['num_processes'])
@@ -133,6 +133,14 @@ def main(args, path_to_candidate_bonds):
133133
t0 = time.time()
134134
model.train()
135135

136+
# Final results
137+
torch.save({'model_state_dict': model.state_dict()},
138+
args['result_path'] + '/model_final.pkl')
139+
prediction_summary = 'final\n' + candidate_ranking_eval(args, model, val_loader)
140+
print(prediction_summary)
141+
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
142+
f.write(prediction_summary)
143+
136144
if __name__ == '__main__':
137145
from argparse import ArgumentParser
138146

examples/reaction_prediction/rexgen_direct/find_reaction_center_train.py

+10
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ def main(rank, dev_id, args):
124124
model.train()
125125
synchronize(args['num_devices'])
126126

127+
# Final results
128+
if rank == 0:
129+
prediction_summary = 'final result ' + \
130+
reaction_center_final_eval(args, args['top_ks_val'], model, val_loader, easy=True)
131+
print(prediction_summary)
132+
with open(args['result_path'] + '/val_eval.txt', 'a') as f:
133+
f.write(prediction_summary)
134+
torch.save({'model_state_dict': model.state_dict()},
135+
args['result_path'] + '/model_final.pkl')
136+
127137
def run(rank, dev_id, args):
128138
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
129139
master_ip=args['master_ip'], master_port=args['master_port'])

examples/reaction_prediction/rexgen_direct/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ def prepare_reaction_center(args, reaction_center_config):
549549
else:
550550
dataset = WLNCenterDataset(raw_file_path=args['{}_path'.format(subset)],
551551
mol_graph_path='{}.bin'.format(subset),
552-
num_processes=args['num_processes'])
552+
num_processes=args['num_processes'],
553+
reaction_validity_result_prefix=subset)
553554

554555
dataloader = DataLoader(dataset, batch_size=args['reaction_center_batch_size'],
555556
collate_fn=collate_center, shuffle=False)

0 commit comments

Comments
 (0)