-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatautils.py
94 lines (70 loc) · 2.57 KB
/
datautils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
'''
Open question: check whether the MNIST dataset fits in memory.
If it does not, find another way to run the same fitting procedure
that we originally intended to use.
'''
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import imageio
def get_binarized_MNIST(split = 'Train', flatten = True, random_seed=0):
    '''
    Load the binarized MNIST dataset packed inside a single PNG image
    (the trick from Alemi's tweet) and return one of its three partitions.

    :param split: partition name, compared case-insensitively. 'train' selects
        the first 50000 samples, 'valid' the next 10000; any other value
        selects the remaining samples.
    :param flatten: if True the images are returned reshaped to (-1, 784),
        otherwise they keep their (-1, 28, 28) layout.
    :param random_seed: seed applied to both the torch and the numpy global RNG.
    :return: mnist data, mnist label
    '''
    # Seed both global generators before doing anything else.
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    # Download the image and flatten its first three channels into a byte stream.
    raw_bytes = imageio.imread("https://i.imgur.com/j0SOfRW.png")[..., :3].ravel()

    # The trailing 70000 bytes are the labels; everything before them holds the
    # pixels, packed eight per byte.
    packed_pixels, all_labels = np.split(raw_bytes, [-70000])
    all_images = np.unpackbits(packed_pixels).reshape((-1, 28, 28))

    # Cut images and labels at the same positions into three partitions.
    boundaries = [50000, 60000]
    image_parts = np.split(all_images, boundaries)
    label_parts = np.split(all_labels, boundaries)

    # Anything that is neither 'train' nor 'valid' maps to the last partition.
    part = {'train': 0, 'valid': 1}.get(split.lower(), 2)

    images = image_parts[part]
    if flatten:
        images = images.reshape(-1, 784)
    return images, label_parts[part]
def get_bernoulli_MNIST(random_seed = 3, verbose = False):
    '''
    Return a Bernoulli-sampled copy of the MNIST training set.

    The dataset is read from the local 'data/' directory (download=False, so it
    must already be present there). Every image is converted with ToTensor,
    flattened to one dimension and passed through ``Tensor.bernoulli()``.

    :param random_seed: seed applied to both the torch and the numpy global RNG
        before sampling.
    :param verbose: if True, show the first sample as a 28x28 image and print
        the shapes of the returned arrays.
    :return: data, labels -- two numpy arrays with one entry per sample
    '''
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    mnist_train = MNIST('data/', train=True, transform=transforms.ToTensor(), download=False)

    data = []
    labels = []
    for i in range(len(mnist_train)):
        # Index the dataset only once per sample: every access re-runs the
        # transform, so fetching image and label separately doubles the work.
        fig, label = mnist_train[i]
        labels.append(label)
        # TODO: the dataset is binarized once here and is static afterwards, so
        # the images do not change during training --> maybe this should be
        # done in a way that keeps the binarization dynamic.
        data.append(fig.view(-1).bernoulli().numpy())

    data = np.array(data)
    labels = np.array(labels)

    if verbose:
        print('Show an example in the dataset')
        plt.imshow(data[0].reshape(28, 28))
        plt.show()
        print('Dataset size')
        print(data.shape)
        print(labels.shape)

    return data, labels
# def get_dataloader_fully_observed(data, labels):
#
# if __name__ == "__main__":
# data, labels = get_bernoulli_MNIST(random_seed=12, verbose=True)