feat: implemented MoViNet-Stream (Closes #2)
martinwholtmon committed Feb 20, 2024
1 parent 8e8e231 commit fd67477
Showing 2 changed files with 30 additions and 9 deletions.
experiment.py (4 changes: 3 additions & 1 deletion)

@@ -86,6 +86,8 @@ def main():
     model = MoViNetModule(
         num_classes=classes_to_predict,
         lr=learning_rate,
+        n_clips=5,
+        n_clip_frames=frames_per_clip / 5,
     )

     # define callbacks
@@ -137,7 +139,7 @@ def main():
     else:
         # Load checkpoint
         print(f"Loading checkpoint: {checkpoint}")
-        model = model.load_from_checkpoint(checkpoint, model=backbone)
+        model = model.load_from_checkpoint(checkpoint)
         trainer.test(model=model, dataloaders=data_module)
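The two new arguments split every sampled clip into fixed-length streaming subclips: n_clips windows of frames_per_clip / 5 frames each. Note that this division yields a float, which trainer.py truncates with int() when slicing (see the trainer.py diff below). A minimal sketch of the resulting window boundaries, assuming frames_per_clip = 250 purely for illustration:

    # Sketch of the subclip boundaries; frames_per_clip = 250 is an assumed value.
    frames_per_clip = 250
    n_clips = 5
    n_clip_frames = frames_per_clip / 5  # 50.0 -- a float, truncated with int() in trainer.py

    for j in range(n_clips):
        start = int(n_clip_frames) * j
        end = int(n_clip_frames) * (j + 1)
        print(f"subclip {j}: frames [{start}, {end})")  # [0, 50), [50, 100), ..., [200, 250)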
trainer.py (35 changes: 27 additions & 8 deletions)

@@ -16,14 +16,16 @@ def __init__(
         lr_min=1e-7,
         momentum=0.9,
         weight_decay=5e-4,
+        n_clips=5,
+        n_clip_frames=50,
     ):
         super().__init__()
         self.save_hyperparameters(ignore=["model"])

         # Define the model
-        model = MoViNet(_C.MODEL.MoViNetA0, causal=False, pretrained=True)
+        model = MoViNet(_C.MODEL.MoViNetA0, causal=True, pretrained=True)
         # Replace last layer to correspond to the number of classes
-        in_channels = model.classifier[-1].conv_1.conv3d.in_channels
+        in_channels = model.classifier[-1].conv_1.conv2d.in_channels
         model.classifier[-1] = torch.nn.Conv3d(
             in_channels, self.hparams.num_classes, (1, 1, 1)
         )
@@ -33,6 +35,7 @@ def __init__(
         self.criterion = torch.nn.CrossEntropyLoss(
             label_smoothing=self.hparams.label_smoothing
         )
+        self.automatic_optimization = False

         metrics = MetricCollection(
             {
@@ -58,13 +61,29 @@ def forward(self, x):
     def _shared_step(
         self, batch, prefix, batch_idx, metric, on_step=False, on_epoch=False
     ):
-        self.model.clean_activation_buffers()
-
-        # Predict
         features, true_labels = batch
-        logits = self(features)
-        loss = self.criterion(logits, true_labels)
-        # predicted_labels = torch.argmax(logits, dim=1)  # Pretty sure torchmetrics does this automatically/supports both formats
+
+        # Divide into subclips and do a backward pass for each clip
+        # frames_per_clip = n_clip_frames * n_clips
+        for j in range(self.hparams.n_clips):
+            clip = features[
+                :,
+                :,
+                int(self.hparams.n_clip_frames)
+                * j : int(self.hparams.n_clip_frames)
+                * (j + 1),
+            ]
+            logits = self(clip)
+            loss = self.criterion(logits, true_labels)
+            if prefix == "train":
+                self.manual_backward(loss)
+        if prefix == "train":
+            optimz = self.optimizers()
+            optimz.step()
+            optimz.zero_grad()
+
+        # Clean
+        self.model.clean_activation_buffers()

         # log metrics
         metric_dict = {f"{prefix}_loss": loss}
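Because automatic_optimization is set to False above, Lightning no longer runs loss.backward() and optimizer.step() on the module's behalf; the reworked _shared_step instead calls self.manual_backward(loss) once per subclip and steps the optimizer obtained from self.optimizers() after the loop. For reference, the underlying PyTorch Lightning manual-optimization contract looks roughly like this (a minimal, self-contained sketch, not code from this repository):

    import pytorch_lightning as pl
    import torch

    class ManualOptExample(pl.LightningModule):
        """Illustrative only: Lightning's manual-optimization pattern."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)
            self.automatic_optimization = False  # same switch as in trainer.py

        def training_step(self, batch, batch_idx):
            x, y = batch
            opt = self.optimizers()              # optimizer from configure_optimizers()
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            self.manual_backward(loss)           # instead of loss.backward()
            opt.step()
            opt.zero_grad()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)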
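At inference time the causal backbone can be driven in true streaming fashion: feed consecutive subclips, let the activation buffers carry temporal state between calls, and reset them between videos. A rough usage sketch under stated assumptions (the import path and checkpoint path are placeholders, MoViNet-A0's 172x172 input resolution is assumed, and reading off the prediction after the last subclip is just one possible choice):

    import torch

    from trainer import MoViNetModule  # assumed to be the LightningModule edited above

    # Hypothetical checkpoint path; load_from_checkpoint restores the saved hyperparameters.
    model = MoViNetModule.load_from_checkpoint("path/to/checkpoint.ckpt")
    model.eval()
    model.model.clean_activation_buffers()  # start from an empty temporal state

    n_clips, n_clip_frames = 5, 50          # mirrors the training defaults
    video = torch.randn(1, 3, n_clips * n_clip_frames, 172, 172)  # dummy (B, C, T, H, W) input

    with torch.no_grad():
        for j in range(n_clips):
            clip = video[:, :, j * n_clip_frames : (j + 1) * n_clip_frames]
            logits = model(clip)             # buffers accumulate context across subclips

    prediction = logits.argmax(dim=1)        # class decision after the final subclip
    model.model.clean_activation_buffers()   # reset before the next video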
