Skip to content

Commit

Permalink
reset metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Angryrou committed Jun 23, 2024
1 parent fc5dd1b commit a0cd26e
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion udao/model/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,22 @@ def predict_step(
features, _ = batch
return self.model(features)

def _reset_metrics(self) -> None:
for objective in self.objectives:
cast(Metric, self.metrics[objective]).reset()

def on_validation_epoch_start(self) -> None:
self._shared_epoch_end("train")
self._reset_metrics()

def validation_step(self, batch: Tuple[Any, th.Tensor], batch_idx: int) -> None:
self._shared_step(batch, "val")

def on_validation_epoch_end(self) -> None:
self._shared_epoch_end("val")

def on_test_epoch_start(self) -> None:
self._reset_metrics()

def test_step(self, batch: Tuple[Any, th.Tensor], batch_idx: int) -> None:
self._shared_step(batch, "test")

Expand Down

0 comments on commit a0cd26e

Please sign in to comment.