방명록
- 인공 신경망 코드로 구현해서 다중 분류해보기 (2)2024년 03월 17일 12시 02분 55초에 업로드 된 글입니다.작성자: 재형이반응형
- 오늘 저녁 약속이 있는데 잘 하고 올게요~
- 오늘은 지난번에 했던 데이터 세트와 다른 데이터로 복습할겸 인공신경망을 구성해서 다시 학습을 해볼 것이다
- 추가로 학습이 완료된 모델을 저장하는 방법도 실습해볼 것이다
- 모델을 저장하는 방법에는 두가지가 있다
- 모델의 파라미터만 저장하는 방법
- 이 방법은 모델의 계층 구조를 알고 있어야만 나중에 다시 해당 모델을 사용할 수 있음
- 모델 전체를 저장하는 방법
- 모델의 파라미터만 저장하는 방법
- 모델을 저장하는 방법에는 두가지가 있다
- 사용할 데이터 세트 : Fashion MNIST
- 참고
복습
- TensorDataset과 DataLoader
- 입력 데이터를 쉽게 처리하고, 배치 단위로 잘러서 학습할 수 있게 도와주는 모듈
- Dataset : 학습시 사용하는 feature와 target의 pair로 이루어짐
- DataLoader: 학습 시 각 인스턴스에 쉽게 접근할 수 있도록 순회 가능한 객체(iterable)를 생성
- DataLoader가 하는 역할
- shuffling
- batch ...
- Device 설정
- 일반적으로 인공신경망의 학습은 (가능하다면) GPU를 사용하는 것이 바람직함
- GPU를 사용하여 학습을 진행하도록 명시적으로 작성 필요
- 연산 유형에 따라 GPU에서 수행이 불가능한 경우도 존재하는데, 그럴 경우도 마찬가지로 명시적으로 어떤 프로세서에서 연산을 수행해야하는지 코드로 작성해야함
- 신경망 생성
- torch.nn 패키지는 신경망 생성 및 학습 시 설정해야하는 다양한 기능을 제공
- 신경망을 nn.Module을 상속받아 정의
- __ init __(): 신경망에서 사용할 layer를 초기화하는 부분
- forward(): feed foward 연산 수행 시, 각 layer의 입출력이 어떻게 연결되는지를 지정
- Model compile
- 학습 시 필요한 정보들(loss function, optimizer)을 선언
- 일반적으로 loss와 optimizer는 아래와 같이 변수로 선언하고, 변수를 train/test 시 참고할 수 있도록 매개변수로 지정해줌
- Train
- 신경망의 학습과정을 별도의 함수로 구성하는 것이 일반적
- feed forward → loss → error back propagation → print(진행상황) 또는 로깅 → (반복)
- 신경망의 학습과정을 별도의 함수로 구성하는 것이 일반적
- Test
- 학습과정과 비슷하나 error back propagate하는 부분이 없음
- feed forward → loss → print(진행상황) 또는 로깅 → (반복)
- 학습과정과 비슷하나 error back propagate하는 부분이 없음
- Iteration
- 신경망 학습은 여러 epochs을 반복해서 수행하면서 모델을 구성하는 최적의 파라미터를 찾음
- 지정한 epochs 수만큼 학습과정과 평가과정을 반복하면서, 모델의 성능(loss, accuracy 등)을 체크함
Fashion MNIST Classifier
- Fashion MNIST 데이터셋을 사용하여 옷의 품목을 구분하는 분류기를 신경망을 사용하여 구현
[Step1] Load libraries & Datasets
- torch.nn : 신경망을 생성하기 위한 기본 재료들을 제공(Modules, Sequential, Layer, Activations, Loss, Dropout...)
- torchvision.datasets : torchvision.transforms를 사용해 변형이 가능한 형태, feature와 label을 반환
- torchvision.transforms
- ToTensor() : ndarray를 FloatTensor로 변환하고 이미지 픽셀 크기를 [0., 1.]범위로 조정(scale)
import numpy as np import matplotlib.pyplot as plt import torch from torch.utils.data import DataLoader from torch import nn from torchvision import datasets from torchvision.transforms import ToTensor # FashionMNIST 데이터 불러오기 training_data = datasets.FashionMNIST( root = 'data', train = True, download = True, transform = ToTensor(), ) test_data = datasets.FashionMNIST( root = 'data', train = False, download = True, transform = ToTensor(), )
- Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:02<00:00, 9298111.12it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 163915.28it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 4422102/4422102 [00:01<00:00, 3134603.50it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 5148/5148 [00:00<00:00, 19700982.66it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
[Step2] Create DataLoader
train_dataloader = DataLoader(training_data, batch_size = 64, shuffle = True) test_dataloader = DataLoader(test_data, batch_size = 64, shuffle = False) # Device 설정 device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f'device = {device}')
- device = cuda
EDA(Exploratory Data Analysis, 탐색적 데이터 분석)
print(training_data, '\n------------------\n', test_data) training_data[0] train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img, label = training_data[0] plt.imshow(img.squeeze(), cmap='gray') print(f'label={label}') labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } figure = plt.figure(figsize = (20, 8)) cols, rows = 5, 2 for i in range(1, cols * rows +1): sample_idx = torch.randint(len(training_data), size=(1,)).item() img, label = training_data[sample_idx] figure.add_subplot(rows, cols, i) plt.title(labels_map[label]) print(labels_map[label]) plt.axis('off') plt.imshow(img.squeeze(), cmap='gray') plt.show()
- Dataset FashionMNIST
Number of datapoints: 60000
Root location: data
Split: Train
StandardTransform
Transform: ToTensor()
------------------
Dataset FashionMNIST
Number of datapoints: 10000
Root location: data
Split: Test
StandardTransform
Transform: ToTensor() - (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,
0.0000, 0.0039, 0.0039, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,
0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,
0.0157, 0.0000, 0.0000, 0.0118],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,
0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,
0.0000, 0.0471, 0.0392, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,
0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,
0.3020, 0.5098, 0.2824, 0.0588],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,
0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,
0.5529, 0.3451, 0.6745, 0.2588],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,
0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,
0.4824, 0.7686, 0.8980, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,
0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,
0.8745, 0.9608, 0.6784, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,
0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,
0.8627, 0.9529, 0.7922, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,
0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,
0.8863, 0.7725, 0.8196, 0.2039],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,
0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,
0.9608, 0.4667, 0.6549, 0.2196],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,
0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,
0.8510, 0.8196, 0.3608, 0.0000],
[0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,
0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,
0.8549, 1.0000, 0.3020, 0.0000],
[0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,
0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,
0.8784, 0.9569, 0.6235, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,
0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,
0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,
0.9137, 0.9333, 0.8431, 0.0000],
[0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,
0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,
0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,
0.8627, 0.9098, 0.9647, 0.0000],
[0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,
0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,
0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,
0.8706, 0.8941, 0.8824, 0.0000],
[0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,
0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,
0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,
0.8745, 0.8784, 0.8980, 0.1137],
[0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,
0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,
0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,
0.8627, 0.8667, 0.9020, 0.2627],
[0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,
0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,
0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,
0.7098, 0.8039, 0.8078, 0.4510],
[0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,
0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,
0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,
0.6549, 0.6941, 0.8235, 0.3608],
[0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,
0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,
0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,
0.7529, 0.8471, 0.6667, 0.0000],
[0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,
0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,
0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,
0.3882, 0.2275, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,
0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]]),
9) - Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
[Step3] Set Network Structure
class NeuralNetwork(nn.Module): def __init__(self): super(NeuralNetwork, self).__init__() self.flatten = nn.Flatten() self.classifier = nn.Sequential( nn.Linear(28*28, 128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 10), ) def forward(self, x): x = self.flatten(x) output = self.classifier(x) return output
[Step4] Create Model instance
model = NeuralNetwork().to(device) print(model)
- NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(classifier): Sequential(
(0): Linear(in_features=784, out_features=128, bias=True)
(1): ReLU()
(2): Dropout(p=0.2, inplace=False)
(3): Linear(in_features=128, out_features=10, bias=True)
)
)
Model 테스트
X = torch.rand(1, 28, 28, device = device) output = model(X) print(f'모델 출력 결과: {output}\n') pred_probab = nn.Softmax(dim=1)(output) print(f'Softmax 결과: {pred_probab}\n') y_pred = pred_probab.argmax() print(y_pred)
- 모델 출력 결과: tensor([[ 0.3374, -0.2491, 0.2523, -0.0858, -0.0171, -0.0398, -0.0953, -0.1535,
-0.0450, 0.0389]], device='cuda:0', grad_fn=<AddmmBackward0>)
Softmax 결과: tensor([[0.1388, 0.0772, 0.1275, 0.0909, 0.0974, 0.0952, 0.0901, 0.0850, 0.0947,
0.1030]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor(0, device='cuda:0')
[Step5] Model compile
# Loss loss = nn.CrossEntropyLoss() # Optimizer learning_rate = 1e-3 #0.001 optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
[Step6] Set train loop
def train_loop(train_loader, model, loss_fn, optimizer): size = len(train_loader.dataset) for batch, (X, y) in enumerate(train_loader): X, y = X.to(device), y.to(device) pred = model(X) # 손실 계산 loss = loss_fn(pred, y) # 역전파 optimizer.zero_grad() loss.backward() optimizer.step() if batch % 100 == 0: loss, current = loss.item(), batch * len(X) print(f'loss: {loss:>7f} [{current:>5d}]/{size:5d}')
[Step7] Set test loop
def test_loop(test_loader, model, loss_fn): size = len(test_loader.dataset) num_batches = len(test_loader) test_loss, correct = 0, 0 with torch.no_grad(): for X, y in test_loader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:8f}\n")
[Step8] Run model
epochs = 10 for i in range(epochs) : print(f"Epoch {i+1} \n------------------------") train_loop(train_dataloader, model, loss, optimizer) test_loop(test_dataloader, model, loss) print("Done!")
- Epoch 1
------------------------
loss: 2.324765 [ 0]/60000
loss: 0.691231 [ 6400]/60000
loss: 0.471458 [12800]/60000
loss: 0.613083 [19200]/60000
loss: 0.444306 [25600]/60000
loss: 0.549569 [32000]/60000
loss: 0.454276 [38400]/60000
loss: 0.423303 [44800]/60000
loss: 0.461990 [51200]/60000
loss: 0.667872 [57600]/60000
Test Error:
Accuracy: 81.8%, Avg loss: 0.497612
Epoch 2
------------------------
loss: 0.371661 [ 0]/60000
loss: 0.409686 [ 6400]/60000
loss: 0.383233 [12800]/60000
loss: 0.420215 [19200]/60000
loss: 0.388781 [25600]/60000
loss: 0.447195 [32000]/60000
loss: 0.476783 [38400]/60000
loss: 0.300698 [44800]/60000
loss: 0.307176 [51200]/60000
loss: 0.498841 [57600]/60000
Test Error:
Accuracy: 84.8%, Avg loss: 0.427579
• • •
Epoch 9
------------------------
loss: 0.353967 [ 0]/60000
loss: 0.383568 [ 6400]/60000
loss: 0.510100 [12800]/60000
loss: 0.307564 [19200]/60000
loss: 0.241812 [25600]/60000
loss: 0.327063 [32000]/60000
loss: 0.312592 [38400]/60000
loss: 0.401675 [44800]/60000
loss: 0.229188 [51200]/60000
loss: 0.273675 [57600]/60000
Test Error:
Accuracy: 87.4%, Avg loss: 0.362324
Epoch 10
------------------------
loss: 0.293197 [ 0]/60000
loss: 0.239528 [ 6400]/60000
loss: 0.118810 [12800]/60000
loss: 0.316047 [19200]/60000
loss: 0.220012 [25600]/60000
loss: 0.263426 [32000]/60000
loss: 0.263754 [38400]/60000
loss: 0.349957 [44800]/60000
loss: 0.210180 [51200]/60000
loss: 0.347154 [57600]/60000
Test Error:
Accuracy: 86.8%, Avg loss: 0.367648
Done!
[Step9] Save & load model
parameter만 저장하고 불러오기
torch.save(model.state_dict(), 'model_weights.pth') model2 = NeuralNetwork().to(device) print(model2) model2.load_state_dict(torch.load('model_weights.pth')) model2.eval() test_loop(test_dataloader, model2, loss)
- NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(classifier): Sequential(
(0): Linear(in_features=784, out_features=128, bias=True)
(1): ReLU()
(2): Dropout(p=0.2, inplace=False)
(3): Linear(in_features=128, out_features=10, bias=True)
)
) - <All keys matched successfully>
- Test Error:
Accuracy: 88.0%, Avg loss: 0.338681
Model 전체를 저장하고 불러오기
torch.save(model, 'model.pth') model3 = torch.load('model.pth') model3.eval() test_loop(test_dataloader, model3, loss)
- Test Error:
Accuracy: 88.0%, Avg loss: 0.338681
반응형'인공지능 > 프레임워크 or 라이브러리' 카테고리의 다른 글
VGGNet을 사용한 이미지 분류기 실습 (2) 2024.03.19 AlexNet을 사용한 이미지 분류기 실습 (0) 2024.03.18 인공 신경망 코드로 구현해서 다중 분류해보기 (1) (0) 2024.03.16 자전거 대여량 예측 - 선형 회귀, 군집 모델 (클러스터링) 실습 (4) 2024.03.15 제조 데이터의 분류기 실습 (0) 2024.03.14 다음글이 없습니다.이전글이 없습니다.댓글