PyTorch Data Preprocessing
```python
import torch
from torchvision import datasets, transforms
```
Import Error
```
ImportError: cannot import name 'PILLOW_VERSION' from 'PIL'
```
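With Pillow 7.0.0 or later, importing torchvision can fail with this error: Pillow 7.0.0 removed the `PILLOW_VERSION` constant, and older torchvision releases still import it.
Downgrading Pillow as shown below resolves the error. Upgrading torchvision to a release that no longer imports `PILLOW_VERSION` should also work.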
```bash
$ pip install pillow==6.2.2
```
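To see which versions are installed without triggering the failing import, a small check along these lines can help. It is a sketch that assumes Python 3.8+ (for `importlib.metadata`); the helper names `installed_version` and `pillow_lacks_constant` are hypothetical.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def pillow_lacks_constant(version):
    """True if this Pillow version is 7.0.0 or newer, where PILLOW_VERSION no longer exists."""
    major = int(version.split(".")[0])
    return major >= 7


# Reads package metadata only, so it works even while `import torchvision` is failing.
for name in ("Pillow", "torchvision"):
    print(name, installed_version(name))
```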
Loading data with DataLoader
In PyTorch, you wrap a dataset in a DataLoader and feed the batches it yields into the model.
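```python
batch_size = 32

# MNIST training set; downloaded to dataset/ on first use
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        'dataset/',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),                          # PIL image -> float tensor in [0, 1]
            transforms.Normalize(mean=(0.5,), std=(0.5,)),  # [0, 1] -> [-1, 1]
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,  # reshuffle the samples every epoch
)
```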
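`ToTensor` scales 8-bit pixel values from [0, 255] to [0.0, 1.0], and `Normalize(mean, std)` then applies (x - mean) / std per channel, so mean = std = 0.5 maps pixels into [-1, 1]. The NumPy sketch below reproduces only that arithmetic; the helper names are hypothetical, and the channel reordering (HWC to CHW) that `ToTensor` also performs is left out.

```python
import numpy as np


def to_tensor_scale(pixels):
    """Mimic the scaling step of ToTensor: uint8 [0, 255] -> float32 [0.0, 1.0]."""
    return pixels.astype(np.float32) / 255.0


def normalize(x, mean=0.5, std=0.5):
    """The per-channel formula Normalize applies: (x - mean) / std."""
    return (x - mean) / std


pixels = np.array([0, 128, 255], dtype=np.uint8)  # black, mid-gray, white
scaled = normalize(to_tensor_scale(pixels))
print(scaled)  # black maps to -1.0, white to 1.0, mid-gray to about 0.0
```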
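```python
test_batch_size = 32

# MNIST test set; same root directory as the training set
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        'dataset/',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # (0.5,) is a one-element tuple; (0.5) would just be a float
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]),
    ),
    batch_size=test_batch_size,
    shuffle=False,  # evaluation does not need shuffled order
)
```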
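Conceptually, the batching and shuffling that both loaders perform can be sketched in plain Python: shuffle the sample indices once per pass, then slice them into chunks of `batch_size`, with the last chunk allowed to be smaller (DataLoader's default `drop_last=False`). `iterate_minibatches` is a hypothetical illustration, not how DataLoader is implemented internally (it also handles collation, worker processes, and more).

```python
import random


def iterate_minibatches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of samples, batch_size at a time; the last batch may be smaller."""
    indices = list(range(len(dataset)))
    if shuffle:
        # With seed=None, each call produces a different order, like a new epoch.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


data = list(range(10))
for batch in iterate_minibatches(data, batch_size=4, seed=0):
    print(len(batch))  # 4, then 4, then 2
```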
Checking the data from the first iteration
```python
images, labels = next(iter(train_loader))  # first batch
images.shape, labels.shape
# => (torch.Size([32, 1, 28, 28]), torch.Size([32]))
```
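The batch shape follows the layout (batch, channel, height, width): each MNIST sample is a (1, 28, 28) tensor, and the DataLoader's default collation stacks 32 of them along a new leading axis, while the 32 integer labels become a tensor of shape (32,). A NumPy analog with zero-filled placeholder samples (not real MNIST data) reproduces the shapes:

```python
import numpy as np

# 32 placeholder samples, each shaped like one MNIST image tensor: (channel, height, width)
samples = [np.zeros((1, 28, 28), dtype=np.float32) for _ in range(32)]
sample_labels = [i % 10 for i in range(32)]

batch_images = np.stack(samples)         # new leading batch axis -> (32, 1, 28, 28)
batch_labels = np.array(sample_labels)   # -> (32,)
print(batch_images.shape, batch_labels.shape)  # (32, 1, 28, 28) (32,)
```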
Data visualization
```python
import numpy as np
import matplotlib.pyplot as plt
# Jupyter magic: render plots inline in the notebook
%matplotlib inline
```
```python
torch_image = torch.squeeze(images[0])  # (1, 28, 28) -> (28, 28)
torch_image.shape
# => torch.Size([28, 28])
```
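Called without a dimension, `torch.squeeze` removes every size-1 axis, which is convenient here but can surprise you when another axis also happens to have size 1. `numpy.squeeze` follows the same rule, so this NumPy analog shows the behavior; passing a dimension (`torch.squeeze(x, 0)` in PyTorch, `axis=0` in NumPy) limits the removal to that axis.

```python
import numpy as np

one_image = np.zeros((1, 28, 28))              # one sample: (channel, height, width)
print(np.squeeze(one_image).shape)             # (28, 28): the size-1 channel axis is removed

batch_of_one = np.zeros((1, 1, 28, 28))        # a batch that happens to hold a single image
print(np.squeeze(batch_of_one).shape)          # (28, 28): every size-1 axis is dropped
print(np.squeeze(batch_of_one, axis=0).shape)  # (1, 28, 28): only the batch axis is dropped
```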
```python
image = torch_image.numpy()  # tensor -> NumPy array for matplotlib
image.shape
# => (28, 28)
```
```python
label = labels[0].item()  # label of the first image as a Python int
plt.title(label)
plt.imshow(image, cmap='gray')
plt.show()
```
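Looking at several samples at once is often more informative than a single image. A minimal sketch with matplotlib subplots, assuming the images have already been squeezed to an (N, 28, 28) NumPy array; `show_grid` is a hypothetical helper, not part of torchvision.

```python
import numpy as np
import matplotlib.pyplot as plt


def show_grid(images, labels, ncols=8):
    """Plot images of shape (N, H, W) in a grid, using each label as the subplot title."""
    n = len(images)
    nrows = (n + ncols - 1) // ncols  # ceiling division so every image gets a cell
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 1.5, nrows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide ticks, including on unused cells
        if i < n:
            ax.imshow(images[i], cmap="gray")
            ax.set_title(str(labels[i]))
    fig.tight_layout()
    return fig
```

With the batch from above: `fig = show_grid(images.squeeze(1).numpy(), labels.numpy())` followed by `plt.show()`. Because the images are normalized to [-1, 1], `imshow` rescales each one to its own min/max for display, which is fine for a quick look.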