[PyTorch] 데이터 불러오기

PyTorch Data Preprocess

1
2
import torch
from torchvision import datasets, transforms

Import Error

1
ImportError: cannot import name 'PILLOW_VERSION' from 'PIL'

pillow 버전이 7.0.0 이상 일경우 Import 에러 나는 경우가 있다.
아래 처럼 pillow 버전을 내려주면 해결이 된다.

1
$ pip install pillow==6.2.2

Data Loader 부르기

Pytorch는 DataLoader를 불러 model에 넣는다.

1
2
3
4
5
6
7
8
9
10
batch_size = 32

train_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset/', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,))
])),
batch_size=batch_size,
shuffle=True)
1
2
3
4
5
6
7
8
9
10
test_batch_size = 32

test_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5))
])),
batch_size=test_batch_size,
shuffle=True)

첫번째 iteration에서 나오는 데이터 확인

1
2
3
4
images, labels = next(iter(train_loader))
image.shape, label.shape

=> torch.Size([32, 1, 28, 28]), torch.Size([32])

데이터 시각화

1
2
3
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
1
2
3
4
5
# squeeze() 함수는 차원의 원소가 1인 차원을 없애준다.
torch_image = torch.squeeze(images[0])
torch_image.shape

=> torch.Size([28, 28])
1
2
3
4
image = torch_image.numpy()
image.shape

=> (28, 28)
1
2
3
plt.title(label)
plt.imshow(image, 'gray')
plt.show()

Share