PyTorch Data Preprocessing
    import torch
    from torchvision import datasets, transforms

 
 Import Error
    ImportError: cannot import name 'PILLOW_VERSION' from 'PIL'

 
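With Pillow 7.0.0 or later, this import can fail, because Pillow 7.0.0 removed the PILLOW_VERSION constant that older torchvision releases still import.
Downgrading Pillow as shown below resolves the error.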
    $ pip install pillow==6.2.2

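Before downgrading, it can help to confirm which Pillow version is actually installed. The sketch below uses only the standard library. The helper name `pillow_lacks_version_constant` is made up for this illustration, and the 7.0.0 threshold is the one mentioned above.

```python
from importlib.metadata import PackageNotFoundError, version


def pillow_lacks_version_constant(version_string):
    """Return True if this Pillow version is 7.0.0 or newer.

    Hypothetical helper for illustration; it only compares the major version.
    """
    major = int(version_string.split(".")[0])
    return major >= 7


try:
    installed = version("pillow")
    print(installed, pillow_lacks_version_constant(installed))
except PackageNotFoundError:
    print("pillow is not installed in this environment")
```

If the helper reports True and the ImportError above appears, the downgrade command applies.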
 
 Creating the Data Loaders
In PyTorch, data is fed to the model through a DataLoader.
    batch_size = 32

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('dataset/', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean=(0.5,), std=(0.5,))
                       ])),
        batch_size=batch_size,
        shuffle=True)

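To see what the two transforms do to a single pixel: `ToTensor` scales 8-bit pixel values from [0, 255] to [0.0, 1.0], and `Normalize(mean, std)` then computes `(x - mean) / std` per channel. With a mean of 0.5 and a std of 0.5, the range becomes [-1.0, 1.0]. The plain-Python sketch below reproduces only that arithmetic. The function names are illustrative and are not torchvision APIs.

```python
def to_unit_range(pixel):
    """Mimic the scaling ToTensor applies to an 8-bit pixel value."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Mimic Normalize: subtract the mean, then divide by the std."""
    return (value - mean) / std


for pixel in (0, 255):
    print(pixel, normalize(to_unit_range(pixel)))
```

Black pixels (0) map to -1.0 and white pixels (255) map to 1.0.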
 
    test_batch_size = 32

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('dataset/', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean=(0.5,), std=(0.5,))
                       ])),
        batch_size=test_batch_size,
        shuffle=True)

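The number of batches per epoch follows from the dataset size and the batch size. The calculation below is a rough sketch that assumes the standard MNIST split of 60,000 training and 10,000 test images, and the DataLoader default of keeping the last, smaller batch (`drop_last=False`). `batches_per_epoch` is a hypothetical helper name.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields in one pass over the data."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(60_000, 32))  # 1875
print(batches_per_epoch(10_000, 32))  # 313
```

With PyTorch itself, `len(train_loader)` and `len(test_loader)` report the same counts.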
 
 Checking the data from the first iteration
    images, labels = next(iter(train_loader))
    images.shape, labels.shape

    => torch.Size([32, 1, 28, 28]), torch.Size([32])

 
 Visualizing the data
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline

 
    torch_image = torch.squeeze(images[0])
    torch_image.shape

    => torch.Size([28, 28])

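The batch shape reads as (batch, channels, height, width), so a single image `images[0]` has shape (1, 28, 28). Called without a `dim` argument, `torch.squeeze` drops every dimension of size 1. The sketch below models this on plain shape tuples. `squeeze_shape` is a hypothetical helper and is not part of PyTorch.

```python
def squeeze_shape(shape):
    """Drop every size-1 dimension, as torch.squeeze does with no dim given."""
    return tuple(size for size in shape if size != 1)


batch_shape = (32, 1, 28, 28)   # (batch, channels, height, width)
single_image = batch_shape[1:]  # shape of images[0]
print(squeeze_shape(single_image))  # (28, 28)
```

Because every size-1 dimension is dropped, squeezing a whole batch of size 1 would also remove the batch dimension. Indexing the channel directly (for example `images[0, 0]`) is one way to avoid that.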
 
    image = torch_image.numpy()
    image.shape

    => (28, 28)

 
    plt.title(str(labels[0].item()))  # label of the first image in the batch
    plt.imshow(image, 'gray')
    plt.show()

 
