-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
27 lines (25 loc) · 1.04 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List
import numpy as np
from models import init_model, set_parameters
def test_model(model_name: str, dataset_name: str, parameters: List[np.ndarray], test_loader: DataLoader, DEVICE: torch.device) -> Dict[str, float]:
    """Evaluate a model built from ``parameters`` on ``test_loader``.

    The model is constructed with ``init_model(dataset_name, model_name)``,
    loaded with ``set_parameters(model, parameters)``, moved to ``DEVICE`` and
    run in eval mode without gradient tracking.

    Args:
        model_name: Model identifier forwarded to ``init_model``.
        dataset_name: Dataset identifier forwarded to ``init_model``.
        parameters: Weights forwarded to ``set_parameters``; presumably one
            array per model tensor -- TODO confirm against
            ``models.set_parameters``.
        test_loader: Yields ``(images, labels)`` batches.
        DEVICE: Device to run evaluation on.

    Returns:
        ``{'test_loss': mean per-sample cross-entropy,
        'test_acc': fraction of correctly classified samples}``, both
        averaged over the samples actually yielded by ``test_loader``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model = init_model(dataset_name, model_name)
    set_parameters(model, parameters)
    model.eval()
    model.to(DEVICE)

    # Sum (not mean) per batch so the final division yields a true
    # per-sample average even when the last batch is smaller.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            # Drop singleton dims such as a trailing (B, 1) label axis, but
            # keep the batch axis: a bare squeeze() turns a batch of one into
            # a 0-d tensor, which breaks CrossEntropyLoss and size(0).
            labels = labels.squeeze()
            if labels.dim() == 0:
                labels = labels.unsqueeze(0)
            loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")
    # Normalize by the number of samples actually evaluated rather than
    # len(test_loader.dataset); they differ when the loader uses a
    # sampler/subset or drop_last=True.
    return {'test_loss': loss / total, 'test_acc': correct / total}