Official implementation of the ICML 2023 paper "Which Invariance Should We Transfer? A Causal Minimax Learning Approach".
Dependencies: causal-learn, graphical_models, networkx, r-base, ggm, pytorch
To install ggm, start an R session and run install.packages('ggm').
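To confirm the environment is ready before running anything, a small check like the one below can help. It is a minimal sketch, not part of the repository: the Python import names (`causallearn` for causal-learn, `torch` for pytorch, and so on) are assumptions about how these packages are usually imported, and the R check assumes `Rscript` is on your PATH.

```python
import importlib.util
import shutil
import subprocess

# Assumed import names for the pip packages listed above
# (causal-learn -> causallearn, pytorch -> torch).
PYTHON_MODULES = ["causallearn", "graphical_models", "networkx", "torch"]


def check_python_modules(names):
    """Return {module_name: bool} telling whether each module can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def check_r_package(package="ggm"):
    """Return True if Rscript exists and `library(package)` loads without error."""
    rscript = shutil.which("Rscript")
    if rscript is None:
        return False
    result = subprocess.run([rscript, "-e", f"library({package})"], capture_output=True)
    return result.returncode == 0


if __name__ == "__main__":
    for name, ok in check_python_modules(PYTHON_MODULES).items():
        print(f"{name}: {'ok' if ok else 'missing'}")
    print(f"R package ggm: {'ok' if check_r_package() else 'missing'}")
```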
Generate synthetic data (or directly use the data provided under ./data/simulation/):
python ./simulation/data.py
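To illustrate what "heterogeneous" data means here, the toy generator below draws samples from several environments in which only one mechanism changes. This is a hypothetical structural causal model written for illustration, not the model used in ./simulation/data.py: `X_S -> Y -> X_M`, where the coefficient of `Y` in `X_M` (the `shift` parameter) differs per environment while `Y | X_S` stays invariant.

```python
import random


def simulate_environment(env_id, n_samples, shift, seed=0):
    """Toy SCM (hypothetical): X_S -> Y -> X_M, only X_M's mechanism depends on E."""
    rng = random.Random(seed + env_id)
    rows = []
    for _ in range(n_samples):
        x_s = rng.gauss(0.0, 1.0)              # stable covariate
        y = 2.0 * x_s + rng.gauss(0.0, 0.5)    # Y | X_S is the same in every environment
        x_m = shift * y + rng.gauss(0.0, 0.5)  # X_M | Y changes with the environment
        rows.append((env_id, x_s, x_m, y))
    return rows


def simulate(shifts, n_samples=100, seed=0):
    """Concatenate all environments; column 0 holds the environment index E."""
    data = []
    for env_id, shift in enumerate(shifts):
        data.extend(simulate_environment(env_id, n_samples, shift, seed))
    return data


if __name__ == "__main__":
    data = simulate(shifts=[0.5, 1.0, -1.0], n_samples=5)
    print(len(data), data[0])
```

Keeping the environment index as an explicit column mirrors the augmented variable $E$ that heterogeneous causal discovery works with.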
Run the training:
python ./simulation/main.py
Visualize the results:
./simulation/draw.ipynb
Download the ADNI dataset here
Preprocess the data and partition it into heterogeneous environments according to age:
python ./adni/extract.py
python ./adni/partition.py
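The age-based partition can be pictured as binning subjects by age, with each bin treated as one environment. The sketch below is only an illustration: the cut points are hypothetical and are not the ones used by ./adni/partition.py.

```python
from bisect import bisect_right


def partition_by_age(ages, edges):
    """Map each age to an environment index using sorted cut points.

    With hypothetical edges [70, 80]: age < 70 -> 0, 70 <= age < 80 -> 1, age >= 80 -> 2.
    """
    return [bisect_right(edges, age) for age in ages]


def group_indices(env_labels):
    """Return {environment index: list of subject indices}."""
    groups = {}
    for idx, env in enumerate(env_labels):
        groups.setdefault(env, []).append(idx)
    return groups


if __name__ == "__main__":
    labels = partition_by_age([65, 72, 80, 85], edges=[70, 80])
    print(labels, group_indices(labels))
```

`bisect_right` places a subject whose age equals a cut point into the upper bin, so each bin is closed on the left and open on the right.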
Run the training:
python ./adni/main.py
Preprocess the data (or directly use the data provided under ./data/impc/):
python ./impc/extract.py
Run the training:
python ./impc/main.py
Command:
python ./causal_discovery/causal_discovery.py -p path_to_data_folder -f data_filename
Example:
python ./causal_discovery/causal_discovery.py -p ./data/simulation/ -f 134581011151617.csv
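Because the script takes the data folder (`-p`) and the file name (`-f`) as separate arguments, a small helper can build the command from a single CSV path. This is a hypothetical convenience wrapper, not part of the repository; it uses only the two flags documented above and keeps the trailing slash on the folder to match the example, in case the script joins the two strings directly (an assumption).

```python
import os
import subprocess
import sys


def causal_discovery_command(csv_path):
    """Split a CSV path into the -p (folder) and -f (file name) arguments."""
    folder, filename = os.path.split(csv_path)
    # Keep a trailing slash, as in the documented example (-p ./data/simulation/).
    folder = (folder or ".").rstrip("/") + "/"
    return [sys.executable, "./causal_discovery/causal_discovery.py",
            "-p", folder, "-f", filename]


def run_causal_discovery(csv_path):
    """Run from the repository root; raises CalledProcessError on failure."""
    subprocess.run(causal_discovery_command(csv_path), check=True)
```

For example, `causal_discovery_command("./data/simulation/134581011151617.csv")` ends with `-p ./data/simulation/ -f 134581011151617.csv`, the same arguments as the command above.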
Explanation:
This command uses the CD-NOD algorithm for heterogeneous causal discovery. It generates the following files under the ./cache/ folder:
- auggraph.gml: augmented causal graph over $\{Y, \mathbf{X}_S, \mathbf{X}_M, E\}$.
- stablegraph.gml: stable causal graph over $\{Y, \mathbf{X}_S\}$.
- graphparse.json: local components such as $\mathbf{X}_M$, $\mathbf{X}_M^0$, $\mathbf{W}$, etc.
- cit_cache.json: intermediate results of the conditional independence (CI) tests.
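The relation between the augmented graph and the stable graph can be sketched on a toy edge list. This is an illustration under simplifying assumptions, not the repository's parsing code: all edges are taken as directed, $\mathbf{X}_M$ is assumed to be the covariates with an edge from $E$ (mechanisms that change across environments), $\mathbf{X}_S$ is every other covariate, and the stable graph is the subgraph induced on $\{Y, \mathbf{X}_S\}$.

```python
def split_variables(aug_edges, target="Y", env="E"):
    """Toy decomposition of an augmented graph given as directed (u, v) edges.

    Assumption: X_M = covariates with an edge from E; X_S = remaining covariates.
    """
    nodes = {n for edge in aug_edges for n in edge}
    x_m = {v for (u, v) in aug_edges if u == env and v != target}
    x_s = nodes - x_m - {target, env}
    keep = x_s | {target}
    # Stable graph: keep only edges whose endpoints are both in {Y} ∪ X_S.
    stable_edges = [(u, v) for (u, v) in aug_edges if u in keep and v in keep]
    return x_m, x_s, stable_edges


if __name__ == "__main__":
    edges = [("E", "X2"), ("X1", "Y"), ("Y", "X2"), ("X3", "X1")]
    print(split_variables(edges))
```

On the example edges this yields $\mathbf{X}_M = \{X_2\}$, $\mathbf{X}_S = \{X_1, X_3\}$, and a stable graph with edges $X_1 \to Y$ and $X_3 \to X_1$.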
Command:
python ./eqcls_search/eqcls_search.py -f stable_dag_gml_file
Example:
python ./eqcls_search/eqcls_search.py -f ./cache/stablegraph.gml
Explanation:
This command implements Alg. 2 of our paper to search for equivalence classes. It generates the following files under the ./cache/ folder:
- eqclses.json: all equivalence classes.
- eqsubsets.json: the $N_G$ subsets that need to be searched.
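The payoff of equivalence classes is that only one representative per class has to be evaluated, instead of every subset of $\mathbf{X}_S$. The sketch below shows that generic grouping idea only; it is not Alg. 2, and the `key` function (here, intersection with a hypothetical set of "relevant" variables) is a placeholder for whatever equivalence criterion the paper actually uses.

```python
from itertools import combinations


def all_subsets(variables):
    """Yield every subset of `variables` as a frozenset."""
    variables = sorted(variables)
    for r in range(len(variables) + 1):
        for subset in combinations(variables, r):
            yield frozenset(subset)


def group_into_classes(variables, key):
    """Group all subsets by key(subset); keep the smallest subset of each class."""
    classes = {}
    for subset in all_subsets(variables):
        classes.setdefault(key(subset), []).append(subset)
    representatives = [min(members, key=lambda s: (len(s), sorted(s)))
                       for members in classes.values()]
    return classes, representatives


if __name__ == "__main__":
    relevant = frozenset({"X1", "X2"})  # hypothetical criterion, for illustration only
    classes, reps = group_into_classes({"X1", "X2", "X3"}, key=lambda s: s & relevant)
    print(len(classes), [sorted(r) for r in reps])
```

With three variables, the 8 subsets collapse into 4 classes under this toy key, so only 4 candidates would need to be evaluated.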
Command:
python ./findoptset/main.py -p graphparser_file -e eqsubsets_file -d path_to_data_folder
Example:
python ./findoptset/main.py -p ./cache/graphparse.json -e ./cache/eqsubsets.json -d ./data/simulation/
Explanation:
This command implements Alg. 1 of our paper to find the optimal subset (invariant predictor). It generates record.json under the ./findoptset/ folder, which records the estimated and ground-truth worst-case risk of each subset.
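The selection rule behind "optimal subset" is minimax: for each candidate subset, take its worst risk over environments, then pick the subset whose worst case is smallest. The sketch below shows only that final selection step on made-up numbers; how the per-environment risks are estimated, and the layout of record.json, are not reproduced here.

```python
def worst_case_risk(risks_by_env):
    """Worst-case risk of one subset: the largest of its per-environment risks."""
    return max(risks_by_env.values())


def select_minimax_subset(risks):
    """risks: {subset (tuple of variable names): {environment: risk}}.

    Return the subset with the smallest worst-case risk, together with that risk.
    """
    best = min(risks, key=lambda s: worst_case_risk(risks[s]))
    return best, worst_case_risk(risks[best])


if __name__ == "__main__":
    # Made-up risks, for illustration only.
    risks = {
        ("X1",): {"e1": 0.30, "e2": 0.50},
        ("X1", "X2"): {"e1": 0.10, "e2": 0.80},
        (): {"e1": 0.60, "e2": 0.60},
    }
    print(select_minimax_subset(risks))
```

Note that `("X1", "X2")` has the lowest risk in environment e1 but loses under the minimax criterion because of its poor performance in e2.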