This is the GitHub page for the paper Semantically consistent part discovery for fine-grained recognition.
First, download the datasets here:
- CelebA: Use the Google Drive link, download img/img_celeba.7z, and unzip.
- CUB: Get CUB_200_2011.tgz.
- PartImageNet: Download the PartImageNet_OOD version.
For the default folder structure, clone the git repo and extract the datasets into their respective folders in /datasets/ (extract the CelebA dataset into a new folder /celeba/unaligned/). So, the default folder structure is:
├── celeba
├── cub
├── datasets
│   ├── celeba
│   │   └── unaligned
│   ├── cub
│   │   ├── CUB_200_2011
│   └── partimagenet
│       ├── test
│       ├── train
│       └── val
└── partimagenet
Next, build the conda environment using the environment.yml file.
Finally, the PartImageNet dataset needs one more preparation step. When we first downloaded the datasets, PartImageNet_OOD was the only existing version of the data. In this version, however, the sets of classes in 'test', 'train', and 'val' are disjoint, so we created two subsets of the 'train' set: a training subset and a testing subset. To prepare the dataset, run the preparation script in datasets/partimagenet.
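The idea behind this split can be sketched in a few lines of Python. This is only an illustration and not the repository's preparation script: the function name split_per_class, the 10% test fraction, the fixed seed, and the assumption that file names are already grouped by class are all hypothetical.

```python
import random


def split_per_class(files_by_class, test_fraction=0.1, seed=0):
    """Split each class's files into a training and a testing subset.

    Both subsets contain every class, which is the property the
    original 'train'/'val'/'test' folders lack (their classes are disjoint).
    The fraction and seed are illustrative assumptions.
    """
    rng = random.Random(seed)
    train, test = {}, {}
    for cls, files in sorted(files_by_class.items()):
        shuffled = sorted(files)  # sort first so the shuffle is reproducible
        rng.shuffle(shuffled)
        # Keep at least one test file per class when the class has more than one file.
        n_test = max(1, round(len(shuffled) * test_fraction)) if len(shuffled) > 1 else 0
        test[cls] = shuffled[:n_test]
        train[cls] = shuffled[n_test:]
    return train, test
```

Splitting inside each class, rather than over the whole file list, guarantees that the testing subset only contains classes the model has seen during training.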
The argument parser has the following parameters:
- --model_name determines the name under which the file containing the parameters will be saved, and the folder in which the results will be saved. The parameters file will be saved as ./[dataset]/[model_name].pt, and the results will be saved in ../results_[model_name].
- --data_root is the folder containing the dataset directories, i.e. the folder containing /celeba, /cub, and /partimagenet.
- --dataset determines which dataset to use. Choose celeba, cub, or partimagenet.
- --num_parts is the number of parts the model should use. In the paper, we trained the model with 4, 8, and 16 parts for CelebA and CUB, and 8, 25, and 50 parts for PartImageNet.
- The learning rate: we used 1e-4 for all datasets.
- --batch_size is the batch size. We used 15 for CUB, and 20 for CelebA and PartImageNet.
- --image_size is the resolution to which the input images will be cropped. First, the short edge of the raw image is resized to image_size, and then the resized image is cropped to (image_size x image_size).
- --epochs is the number of epochs to run. The default for CelebA is 15, the default for CUB is 28, and the default for PartImageNet is 20.
- --pretrained_model_name is only used when evaluating the model, or when continuing the training process with a previously saved model. It should be equal to the --model_name parameter used when the model was trained. Use it in conjunction with --warm_start True.
- --save_figures determines whether attention maps are saved in the validation function.
- --only_test is used when you do not want to train the model and only wish to evaluate it using a saved set of parameters. When you use this option, you should also use --warm_start and --pretrained_model_name.
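Two of the conventions described above, where the outputs are written and how image_size shapes the input, can be made concrete with a small sketch. The helper names output_paths and resize_and_crop_geometry are hypothetical and are not part of the repository. The crop is assumed to be centered, which the description above does not specify.

```python
from pathlib import Path


def output_paths(dataset, model_name):
    """Return (parameters file, results folder) for a given run.

    Follows the naming convention ./[dataset]/[model_name].pt and
    ../results_[model_name] described in the parameter list.
    """
    params_file = Path(".") / dataset / f"{model_name}.pt"
    results_dir = Path("..") / f"results_{model_name}"
    return params_file, results_dir


def resize_and_crop_geometry(width, height, image_size):
    """Compute the resized size and the square crop box for one image.

    The short edge is scaled to image_size (aspect ratio preserved),
    then an image_size x image_size box is cut out. Centering the box
    is an assumption made for this sketch.
    """
    scale = image_size / min(width, height)
    # max() guards against rounding pushing an edge below image_size.
    new_w = max(image_size, round(width * scale))
    new_h = max(image_size, round(height * scale))
    left = (new_w - image_size) // 2
    top = (new_h - image_size) // 2
    return (new_w, new_h), (left, top, left + image_size, top + image_size)
```

For example, output_paths("cub", "cub_8parts") gives the parameters file and results folder for a CUB run trained with --model_name cub_8parts, and resize_and_crop_geometry(500, 375, 448) gives the geometry for a 500 x 375 image at --image_size 448.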
See below for some examples for each dataset.
python --model_name celeb_8parts --data_root ./datasets --dataset celeba --num_parts 8 --batch_size 20 --image_size 256 --epochs 15
python --model_name celeb_8parts --data_path ../datasets/celeba --num_parts 8 --pretrained_model_name celeb_8parts --image_size 256 --warm_start True --only_test True
python --model_path ./celeba/ --data_root ./datasets --num_parts 8 --image_size 256
python --model_name cub_8parts --data_root ./datasets --dataset cub --num_parts 8 --batch_size 16 --image_size 448 --epochs 28
python --model_name cub_8parts --data_root ./datasets --dataset cub --num_parts 8 --batch_size 16 --image_size 448 --pretrained_model_path ./cub/ --save_figures --only_test
python --model_path ./cub/ --data_root ./datasets --num_parts 8 --image_size 448
python --model_name partimagenet_25parts --data_root ./datasets --dataset partimagenet --num_parts 25 --batch_size 20 --image_size 224 --epochs 20
python --model_name partimagenet_25parts --data_root ./datasets --dataset partimagenet --num_parts 25 --batch_size 20 --image_size 224 --pretrained_model_path ./partimagenet/ --save_figures --only_test
python --model_path ./partimagenet/ --data_root ./datasets --num_parts 25 --image_size 224
If you find any bugs, please either send me an e-mail or open an issue on GitHub.