finetune.py
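"""Fine-tune a pretrained BirdSAT model on the CrossViewiNATBirds dataset.

Loads the checkpoint for the model type selected in the config (MAE, MoCoGeo,
CVEMAE, or CVMMAE), switches it to fine-tuning mode, and trains it with
PyTorch Lightning, logging metrics to Weights & Biases. Run with:
python finetune.py
"""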
import torch
import pytorch_lightning as pl
from datasets import CrossViewiNATBirdsFineTune
from models import MAE, CVEMAEMeta, CVMMAEMeta, MoCoGeo
from torch.utils.data import random_split
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np
from pytorch_lightning.loggers import WandbLogger
import json
import pandas as pd
from config import cfg
from utils import seed_everything


def finetune():
    torch.cuda.empty_cache()
    seed_everything()

    logger = WandbLogger(project="BirdSAT", name=cfg.finetune.train.expt_name)

    if cfg.finetune.train.dataset == "iNAT":
        train_json = json.load(open("data/train_birds.json"))
        train_labels = pd.read_csv("data/train_birds_labels.csv")
        train_dataset = CrossViewiNATBirdsFineTune(train_json, train_labels)

        val_json = json.load(open("data/val_birds.json"))
        val_labels = pd.read_csv("data/val_birds_labels.csv")
        val_dataset = CrossViewiNATBirdsFineTune(val_json, val_labels, val=True)
        # Validate on a random 20% subset; compute the split sizes explicitly
        # so they always sum to len(val_dataset).
        val_size = int(0.2 * len(val_dataset))
        val_dataset, _ = random_split(
            val_dataset, [val_size, len(val_dataset) - val_size]
        )
    checkpoint = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints",
        # Interpolate the experiment name here; "{epoch}" and "{val_loss}" are
        # left as literal placeholders for ModelCheckpoint to fill in at save time.
        filename=f"{cfg.pretrain.train.expt_name}" + "-{epoch:02d}-{val_loss:.2f}",
        mode="min",
    )
    # Restore the pretrained weights for the selected model type, then switch
    # the model into fine-tuning mode.
    if cfg.finetune.train.model_type == "MAE":
        model = MAE.load_from_checkpoint(
            cfg.finetune.train.ckpt,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
        )
    elif cfg.finetune.train.model_type == "MOCOGEO":
        model = MoCoGeo.load_from_checkpoint(
            cfg.finetune.train.ckpt,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            queue_dataset=None,
        )
    elif cfg.finetune.train.model_type == "CVEMAE":
        model = CVEMAEMeta.load_from_checkpoint(
            cfg.finetune.train.ckpt,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
        )
    elif cfg.finetune.train.model_type == "CVMMAE":
        model = CVMMAEMeta.load_from_checkpoint(
            cfg.finetune.train.ckpt,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
        )
    else:
        raise ValueError(f"Unknown model_type: {cfg.finetune.train.model_type}")
    model.setup_finetune()
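    # Launch fine-tuning; the trainer reuses the device count and epoch budget
    # from the pretraining section of the config.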
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=cfg.pretrain.train.devices,
        strategy="ddp_find_unused_parameters_true",
        max_epochs=cfg.pretrain.train.num_epochs,
        num_nodes=1,
        callbacks=[checkpoint],
        logger=logger,
    )
    trainer.fit(model)

if __name__ == "__main__":
    finetune()