-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathhparams.py
40 lines (33 loc) · 904 Bytes
/
hparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from collections import OrderedDict
# Hyperparameter table. Every key maps to a *list* of candidate values.
# NOTE(review): presumably the code that reads HP expands these lists into a
# grid of configurations -- confirm against the consumer of this module.
HP = OrderedDict([
    # ---- dataset ----
    ('data', ['lalonde']),
    ('dataroot', ['PATH_TO_YOUR_DATASETS']),  # TODO: MODIFY THIS PATH LOCALLY
    # Entries below are intentionally disabled; kept for reference only.
    # ('saveroot', ['save']),
    # ('train', [True]),
    # ('eval', [True]),
    # ('overwrite_reload', ['']),

    # ---- distribution of outcome (y) ----
    ('dist', ['SigmoidFlow']),
    (
        'dist_args',
        [
            ['ndim=10', 'base_distribution=normal'],
            ['ndim=5', 'base_distribution=uniform'],
        ],
    ),
    # Each option for `atoms` is a list of floats, or an empty list.
    ('atoms', [[0.0], []]),

    # ---- architecture ----
    ('n_hidden_layers', [1]),
    ('dim_h', [64]),
    ('activation', ['ReLU']),

    # ---- training params ----
    ('lr', [0.001]),
    ('batch_size', [64]),
    ('num_epochs', [10]),
    ('early_stop', [True]),
    ('ignore_w', [False]),
    ('grad_norm', ['inf']),  # stored as the string 'inf', not a float
    ('w_transform', ['Standardize']),
    ('y_transform', ['Normalize']),
    # The three split proportions below sum to 1.0.
    ('train_prop', [0.5]),
    ('val_prop', [0.1]),
    ('test_prop', [0.4]),
    ('seed', [123]),

    # ---- evaluation ----
    ('num_univariate_tests', [100]),
])