
Demo on how to save and load models
Authors: Eric Roberts and Petrus Zwart
E-mail: PHZwart@lbl.gov, EJRoberts@lbl.gov
This notebook highlights some basic functionality of the dlsia package.
Using the dlsia framework, we initialize several convolutional neural networks, train each one on a small dataset on the CPU, and show how to save and load them.
[1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from dlsia.core import helpers, custom_losses, train_scripts
from dlsia.core.networks import msdnet, tunet, tunet3plus, smsnet
import matplotlib.pyplot as plt
Create & load data
Generate random data
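Let’s build some random data: 40 single-channel, 36-by-36 images of uniform noise are blurred with a randomly initialized 11-by-11 convolution to form the ground truth, and a second noise field, scaled by 0.5, is added to the blurred images to form the observations. Both are cropped to 32-by-32; the first 20 images are used for training and the last 20 for testing.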
[2]:
n_imgs = 40
n_channels = 1
n_xy = 36
# Random "clean" images
random_data1 = torch.rand((n_imgs, n_channels, n_xy, n_xy))
# Randomly initialized 11x11 convolution used as a blurring operator (padding=5 keeps 36x36)
k = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
# Independent noise field
random_data2 = torch.rand((n_imgs, n_channels, n_xy, n_xy))
random_data_gt = k(random_data1)
random_data_obs = random_data_gt + random_data2*0.50
# Crop to 32x32 and detach from the autograd graph created by k
K = 3
random_data_obs = random_data_obs[:,:,K:32+K,K:32+K].detach()
random_data_gt = random_data_gt[:,:,K:32+K,K:32+K].detach()
# First 20 images for training, last 20 for testing
train_x = random_data_obs[:20,...]
train_y = random_data_gt[:20,...]
test_x = random_data_obs[20:,...]
test_y = random_data_gt[20:,...]
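A quick sanity check of the shape bookkeeping above, as a sketch that reuses the same sizes with fresh random tensors (the names `blur`, `clean`, `noise`, `blurred`, `observed`, `gt` and `obs` are ours). A stride-1 convolution maps a side length n to n + 2·padding − kernel_size + 1, so the 11-by-11 kernel with `padding=5` keeps 36-by-36, and the slice `K:32+K` with `K = 3` trims it to 32-by-32. Running the convolution under `torch.no_grad()` avoids building an autograd graph, which is why no `.detach()` is needed here.

```python
import torch
import torch.nn as nn

n_imgs, n_channels, n_xy, K = 40, 1, 36, 3

# Same kind of blurring operator as above: 11x11 kernel, padding 5 keeps the size
blur = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, stride=1, padding=5)
clean = torch.rand((n_imgs, n_channels, n_xy, n_xy))
noise = torch.rand((n_imgs, n_channels, n_xy, n_xy))

with torch.no_grad():
    blurred = blur(clean)              # ground truth before cropping
    observed = blurred + 0.5 * noise   # noisy observation

# 36 + 2*5 - 11 + 1 = 36
print(blurred.shape)  # torch.Size([40, 1, 36, 36])

# Crop a 32x32 window starting at offset K in both spatial dimensions
gt = blurred[:, :, K:32 + K, K:32 + K]
obs = observed[:, :, K:32 + K, K:32 + K]
print(gt.shape, obs.shape)
```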
Prep data
The data are already tensors; for ingestion into the dlsia pipeline we wrap them in PyTorch TensorDataset and DataLoader objects. This makes it easy to batch the data and iterate over it when feeding it to the networks during training and evaluation.
[3]:
train_set = TensorDataset(train_x, train_y)
test_set = TensorDataset(test_x, test_y)
# Specify batch sizes
batch_size_train = 20
batch_size_test = 20
# Set Dataloader parameters (Note: we randomly shuffle the training set upon each pass)
train_loader_params = {'batch_size': batch_size_train,'shuffle': True}
test_loader_params = {'batch_size': batch_size_test, 'shuffle': False}
# Build Dataloaders
train_loader = DataLoader(train_set, **train_loader_params)
test_loader = DataLoader(test_set, **test_loader_params)
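With 20 training images and `batch_size=20`, each epoch consists of exactly one batch. The sketch below, using random stand-in tensors of the same shape, shows how the number of batches per epoch follows from the dataset size and the batch size; the batch size of 8 is an arbitrary illustrative choice.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in tensors with the same shapes as train_x / train_y above
x = torch.rand(20, 1, 32, 32)
y = torch.rand(20, 1, 32, 32)
dataset = TensorDataset(x, y)

for batch_size in (20, 8):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    sizes = [xb.shape[0] for xb, yb in loader]
    # With drop_last=False (the default) the last batch may be smaller
    assert len(loader) == math.ceil(len(dataset) / batch_size)
    print(batch_size, len(loader), sizes)
# 20 1 [20]
# 8 3 [8, 8, 4]
```

Shuffling changes which images land in each batch, not the batch sizes.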
Construct Networks
dlsia offers a variety of convolutional neural network architectures.
MSDNet
Mixed-scale Dense networks that probe different length scales using dilated convolutions.
[4]:
msdnet_model = msdnet.MixedScaleDenseNetwork(
    in_channels=1,
    out_channels=1,
    num_layers=40)
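To illustrate the idea behind MSDNets, here is a minimal mixed-scale dense network written in plain PyTorch. It is a hypothetical sketch, not dlsia's `MixedScaleDenseNetwork`: we assume single-channel hidden layers, 3-by-3 kernels, ReLU activations and dilations cycling through 1 to 10. Every layer receives the concatenation of the input and all earlier layer outputs (dense connectivity), and a dilation d with padding d keeps the spatial size fixed while widening the receptive field.

```python
import torch
import torch.nn as nn

class TinyMSD(nn.Module):
    """Hypothetical mixed-scale dense network sketch (not dlsia's MSDNet)."""

    def __init__(self, in_channels=1, out_channels=1, num_layers=10, max_dilation=10):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            d = i % max_dilation + 1
            # Layer i sees the input plus the i single-channel maps produced so far
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels + i, 1, kernel_size=3, dilation=d, padding=d),
                nn.ReLU(),
            ))
        # A final 1x1 convolution maps all accumulated features to the output
        self.final = nn.Conv2d(in_channels + num_layers, out_channels, kernel_size=1)

    def forward(self, x):
        features = x
        for layer in self.layers:
            new = layer(features)
            features = torch.cat([features, new], dim=1)  # dense connectivity
        return self.final(features)

model = TinyMSD()
out = model(torch.rand(2, 1, 32, 32))
print(out.shape)  # torch.Size([2, 1, 32, 32])
```

Because each layer adds only one channel, the parameter count grows slowly with depth, which is consistent with the comparatively small MSDNet count printed further down.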
TUNet
Tuneable U-Nets with a variety of user-customizable parameters.
[5]:
tunet_model = tunet.TUNet(
    image_shape=(32, 32),
    in_channels=1,
    out_channels=1,
    depth=3,
    base_channels=10)
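For readers new to U-Nets, the core pattern is an encoder that downsamples, a decoder that upsamples, and skip connections that concatenate encoder features into the decoder. The two-level module below is a hypothetical sketch with hand-picked channel counts, max pooling and a transposed convolution; it is not dlsia's TUNet and does not reproduce how `depth` or `base_channels` are interpreted there. Note that the input's height and width must survive the halving and doubling (here 32 → 16 → 32).

```python
import torch
import torch.nn as nn

def conv_block(c_in, c_out):
    # Two 3x3 convolutions with padding 1 keep the spatial size
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
        nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
    )

class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net sketch (not dlsia's TUNet)."""

    def __init__(self, in_channels=1, out_channels=1, base=10):
        super().__init__()
        self.enc1 = conv_block(in_channels, base)
        self.pool = nn.MaxPool2d(2)
        self.bottom = conv_block(base, 2 * base)
        self.up = nn.ConvTranspose2d(2 * base, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(2 * base, base)  # skip (base) + upsampled (base)
        self.head = nn.Conv2d(base, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                          # (N, base, 32, 32)
        b = self.bottom(self.pool(e1))             # (N, 2*base, 16, 16)
        u = self.up(b)                             # (N, base, 32, 32)
        d1 = self.dec1(torch.cat([e1, u], dim=1))  # skip connection
        return self.head(d1)

out = TinyUNet()(torch.rand(2, 1, 32, 32))
print(out.shape)  # torch.Size([2, 1, 32, 32])
```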
TUNet3+
A newer U-Net variant that connects all encoder and decoder layers via carefully crafted upsampling/downsampling/convolution/concatenation bundles.
[6]:
tunet3plus_model = tunet3plus.TUNet3Plus(
    image_shape=(32, 32),
    in_channels=1,
    out_channels=1,
    depth=3,
    base_channels=10)
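The "bundles" above can be pictured as follows: a decoder node gathers feature maps from several levels, resamples each to its own resolution, passes each through a convolution that brings it to a common channel count, and concatenates the results. The module below is a hypothetical sketch of that gathering step in the spirit of UNet 3+, assuming square feature maps, max pooling for finer maps, bilinear interpolation for coarser ones and a made-up `cat_channels=8`; it is not dlsia's TUNet3Plus code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FullScaleAggregate(nn.Module):
    """Hypothetical UNet3+-style aggregation for one decoder node."""

    def __init__(self, in_channels_per_level, cat_channels=8):
        super().__init__()
        # One 3x3 conv per incoming level, mapping each to the same channel count
        self.convs = nn.ModuleList(
            nn.Conv2d(c, cat_channels, 3, padding=1) for c in in_channels_per_level
        )
        n = len(in_channels_per_level)
        self.fuse = nn.Conv2d(n * cat_channels, n * cat_channels, 3, padding=1)

    def forward(self, features, target_size):
        resized = []
        for f, conv in zip(features, self.convs):
            if f.shape[-1] > target_size:    # finer map: downsample
                f = F.adaptive_max_pool2d(f, target_size)
            elif f.shape[-1] < target_size:  # coarser map: upsample
                f = F.interpolate(f, size=target_size, mode="bilinear",
                                  align_corners=False)
            resized.append(conv(f))
        return torch.relu(self.fuse(torch.cat(resized, dim=1)))

# Three levels with 10, 20 and 40 channels at 32x32, 16x16 and 8x8
feats = [torch.rand(2, 10, 32, 32), torch.rand(2, 20, 16, 16), torch.rand(2, 40, 8, 8)]
node = FullScaleAggregate([10, 20, 40])
out = node(feats, target_size=16)
print(out.shape)  # torch.Size([2, 24, 16, 16])
```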
SMSNets
Sparse Mixed-Scale Networks: lean, randomly and sparsely connected variants of MSDNets.
[7]:
smsnet_model = smsnet.random_SMS_network(
    in_channels=1,
    out_channels=1,
    hidden_out_channels=[1],
    layers=40,
    dilation_choices=[1, 2, 3, 4],
    # layer_probabilities=layer_probabilities,  # optional; not defined in this notebook
    network_type="Regression")
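To make "randomly and sparsely connected" concrete: where an MSDNet connects every layer to all earlier ones, a sparse variant keeps only a random subset of those connections. The pure-Python sketch below samples such a feed-forward (acyclic) connection pattern with an assumed keep-probability of 0.3 and draws one dilation per layer from `dilation_choices`. The function `sample_sparse_dag` is hypothetical and only illustrates the idea; it does not reproduce dlsia's sampling scheme.

```python
import random

def sample_sparse_dag(num_layers, keep_prob=0.3, dilation_choices=(1, 2, 3, 4), seed=0):
    """Hypothetical sampler: node 0 is the input, nodes 1..num_layers are hidden layers."""
    rng = random.Random(seed)
    edges = {}
    dilations = {}
    for j in range(1, num_layers + 1):
        # Only earlier nodes may feed node j, so the graph stays acyclic
        sources = [i for i in range(j) if rng.random() < keep_prob]
        if not sources:
            sources = [j - 1]  # guarantee every layer has at least one input
        edges[j] = sources
        dilations[j] = rng.choice(dilation_choices)
    return edges, dilations

edges, dilations = sample_sparse_dag(num_layers=40)
n_edges = sum(len(s) for s in edges.values())
n_dense = 40 * 41 // 2  # connections if every layer saw all earlier nodes
print(f"{n_edges} of {n_dense} possible connections kept")
```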
[8]:
# View number of learnable parameters in each network
print("MSDNet : ", helpers.count_parameters(msdnet_model), "parameters")
print("TUNet : ", helpers.count_parameters(tunet_model), "parameters")
print("TUNet3plus : ", helpers.count_parameters(tunet3plus_model), "parameters")
print("SMSNet : ", helpers.count_parameters(smsnet_model), "parameters")
MSDNet : 9182 parameters
TUNet : 46131 parameters
TUNet3plus : 49421 parameters
SMSNet : 735 parameters
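A learnable-parameter count of this kind is typically computed by summing `numel()` over all parameters with `requires_grad=True`; we assume `helpers.count_parameters` does something equivalent, but have not checked its source. The helper name `count_learnable` is ours. As a hand-checkable example, the 11-by-11 single-channel blurring convolution from the data-generation cell has 11·11 weights plus one bias.

```python
import torch.nn as nn

def count_learnable(model: nn.Module) -> int:
    # Sum the element counts of all tensors the optimizer would update
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

blur = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, padding=5)
print(count_learnable(blur))  # 122 = 11*11 weights + 1 bias
```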
Training
Training parameters
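Here we choose the training hyperparameters (number of epochs, loss function and learning rate) and create one Adam optimizer per network.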
[9]:
epochs = 30
criterion = nn.MSELoss()
learning_rate = 1e-2
# Define optimizers, one per network
optimizer_msd = optim.Adam(msdnet_model.parameters(), lr=learning_rate)
optimizer_tunet = optim.Adam(tunet_model.parameters(), lr=learning_rate)
optimizer_tunet3plus = optim.Adam(tunet3plus_model.parameters(), lr=learning_rate)
optimizer_smsnet = optim.Adam(smsnet_model.parameters(), lr=learning_rate)
device = "cpu"
#device = helpers.get_device() # Uncomment to get detected GPU
print('Device we will compute on: ', device) # cuda:0 for GPU. Else, CPU
Device we will compute on: cpu
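If a GPU is available, the device can also be selected with plain PyTorch. This is a generic sketch, not the implementation of `helpers.get_device`; it falls back to the CPU when CUDA is not available. We use the name `auto_device` so that pasting it into this notebook does not silently change `device`.

```python
import torch
import torch.nn as nn

auto_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device we will compute on:", auto_device)

# Model and data must live on the same device for the forward pass
model = nn.Conv2d(1, 1, kernel_size=3, padding=1).to(auto_device)
x = torch.rand(4, 1, 32, 32, device=auto_device)
y = model(x)

# Move results back to the CPU before plotting or converting them
y_cpu = y.detach().cpu()
```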
Training loops
[10]:
msdnet_model.to(device)
msdnet_model, results = train_scripts.train_regression(
    msdnet_model,
    train_loader,
    test_loader,
    epochs,
    criterion,
    optimizer_msd,
    device,
    show=10)
msdnet_model = msdnet_model.cpu()
# clear out unnecessary variables from device (GPU) memory
#torch.cuda.empty_cache()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 3.0587e-02 | Validation Loss: 2.9307e-02
Training CC: 0.2502 Validation CC : 0.2717
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 2.5751e-02 | Validation Loss: 2.5685e-02
Training CC: 0.3672 Validation CC : 0.3708
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 2.4834e-02 | Validation Loss: 2.4734e-02
Training CC: 0.4068 Validation CC : 0.4111
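`train_scripts.train_regression` wraps the usual optimization loop. For readers who want to see the individual steps, below is a minimal hand-written loop on a stand-in problem (a single 3-by-3 convolution learning to reproduce its input). It is a sketch under our own assumptions: it omits the correlation-coefficient (CC) reporting and any other bookkeeping dlsia performs, and should not be read as dlsia's implementation. The `toy_` names avoid clobbering the notebook's variables.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = criterion(model(xb), yb)  # forward pass and loss
        loss.backward()                  # backpropagate
        optimizer.step()                 # update the weights
        total += loss.item() * xb.shape[0]
    return total / len(loader.dataset)   # mean per-sample loss over the epoch

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        total += criterion(model(xb), yb).item() * xb.shape[0]
    return total / len(loader.dataset)

# Stand-in problem: learn to reproduce the input
torch.manual_seed(0)
toy_train = torch.rand(20, 1, 32, 32)
toy_test = torch.rand(20, 1, 32, 32)
toy_train_loader = DataLoader(TensorDataset(toy_train, toy_train), batch_size=20, shuffle=True)
toy_test_loader = DataLoader(TensorDataset(toy_test, toy_test), batch_size=20, shuffle=False)

toy_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-2)

history = []
for epoch in range(1, 31):
    history.append(train_one_epoch(toy_model, toy_train_loader, nn.MSELoss(), toy_optimizer, "cpu"))
    if epoch % 10 == 0:
        test_loss = evaluate(toy_model, toy_test_loader, nn.MSELoss(), "cpu")
        print(f"Epoch {epoch} | Training Loss: {history[-1]:.4e} | Validation Loss: {test_loss:.4e}")
```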
[11]:
tunet_model.to(device)
tunet_model, results = train_scripts.train_regression(
    tunet_model,
    train_loader,
    test_loader,
    epochs,
    criterion,
    optimizer_tunet,
    device,
    show=10)
tunet_model = tunet_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.1023e-02 | Validation Loss: 2.0195e-02
Training CC: 0.5513 Validation CC : 0.5673
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.6087e-02 | Validation Loss: 1.5932e-02
Training CC: 0.6799 Validation CC : 0.6858
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.4304e-02 | Validation Loss: 1.4126e-02
Training CC: 0.7219 Validation CC : 0.7252
[12]:
tunet3plus_model.to(device)
tunet3plus_model, results = train_scripts.train_regression(
    tunet3plus_model,
    train_loader,
    test_loader,
    epochs,
    criterion,
    optimizer_tunet3plus,
    device,
    show=10)
tunet3plus_model = tunet3plus_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 2.9115e-02 | Validation Loss: 2.7833e-02
Training CC: 0.4941 Validation CC : 0.5236
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.7936e-02 | Validation Loss: 1.7812e-02
Training CC: 0.6395 Validation CC : 0.6462
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.5100e-02 | Validation Loss: 1.4969e-02
Training CC: 0.7016 Validation CC : 0.7050
[13]:
smsnet_model.to(device)
smsnet_model, results = train_scripts.train_regression(
    smsnet_model,
    train_loader,
    test_loader,
    epochs,
    criterion,
    optimizer_smsnet,
    device,
    show=10)
smsnet_model = smsnet_model.cpu()
Epoch 10 of 30 | Learning rate 1.000e-02
Training Loss: 1.9555e-02 | Validation Loss: 1.8536e-02
Training CC: 0.5966 Validation CC : 0.6203
Epoch 20 of 30 | Learning rate 1.000e-02
Training Loss: 1.5011e-02 | Validation Loss: 1.4870e-02
Training CC: 0.7039 Validation CC : 0.7076
Epoch 30 of 30 | Learning rate 1.000e-02
Training Loss: 1.3773e-02 | Validation Loss: 1.3678e-02
Training CC: 0.7334 Validation CC : 0.7355
Save Networks
Each network class provides a save_network_parameters method for saving a trained network. Each call writes a .pt file containing:
model’s state_dict: the network parameters learned through optimization during training,
model’s topo_dict: the network hyperparameters needed to re-create the same architecture.
This follows standard PyTorch practice: rather than pickling the entire trained model object, we save only its weights together with the hyperparameters, and later load the weights into a freshly constructed network.
[14]:
msdnet_model.save_network_parameters("this_msdnet.pt")
smsnet_model.save_network_parameters("this_smsnet.pt")
tunet_model.save_network_parameters("this_tunet.pt")
Load networks from file
Each network module provides a function that reads such a .pt file, rebuilds the architecture from the stored hyperparameters, and loads the learned weights into it.
[15]:
copy_msdnet = msdnet.MSDNetwork_from_file("this_msdnet.pt")
copy_smsnet = smsnet.SMSNetwork_from_file("this_smsnet.pt")
copy_tunet = tunet.TUNetwork_from_file("this_tunet.pt")
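The save/load pattern described above can be reproduced with plain PyTorch: store the constructor arguments next to the `state_dict`, then rebuild the network from those arguments and load the weights. In this sketch, `TinyNet`, its `save_checkpoint` method, `tinynet_from_file` and the key names `"topo_dict"` and `"state_dict"` are hypothetical stand-ins; dlsia's own file layout may differ. `map_location="cpu"` lets a checkpoint written on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical network whose architecture is fully set by its constructor args."""

    def __init__(self, in_channels=1, out_channels=1, hidden=8):
        super().__init__()
        self.topo_dict = {"in_channels": in_channels,
                          "out_channels": out_channels,
                          "hidden": hidden}
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1), nn.ReLU(),
            nn.Conv2d(hidden, out_channels, 3, padding=1),
        )

    def forward(self, x):
        return self.body(x)

    def save_checkpoint(self, path):
        # Hyperparameters plus learned weights, not the pickled module object
        torch.save({"topo_dict": self.topo_dict,
                    "state_dict": self.state_dict()}, path)

def tinynet_from_file(path):
    checkpoint = torch.load(path, map_location="cpu")
    net = TinyNet(**checkpoint["topo_dict"])       # rebuild the architecture
    net.load_state_dict(checkpoint["state_dict"])  # restore the learned weights
    return net

net = TinyNet(hidden=4)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "this_tinynet.pt")
    net.save_checkpoint(path)
    net_copy = tinynet_from_file(path)
```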
Verify loaded networks
To verify the round trip, we run each original network and its loaded copy on the test images and check that the outputs agree to within 1e-8.
[16]:
with torch.no_grad():
    r1 = msdnet_model(test_x)
    r2 = copy_msdnet(test_x)
    delta = r1 - r2
assert torch.max(torch.abs(delta)) < 1e-8
[17]:
with torch.no_grad():
    r1 = smsnet_model(test_x)
    r2 = copy_smsnet(test_x)
    delta = r1 - r2
assert torch.max(torch.abs(delta)) < 1e-8
[18]:
with torch.no_grad():
    r1 = tunet_model(test_x)
    r2 = copy_tunet(test_x)
    delta = r1 - r2
assert torch.max(torch.abs(delta)) < 1e-8
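Two complementary checks, shown here on a stand-in pair of networks where `copy.deepcopy` plays the role of loading from file: comparing every tensor in the two `state_dict`s for exact equality, and `torch.testing.assert_close`, which raises an error describing the largest mismatch if the outputs differ. The stand-in model is our own choice for illustration.

```python
import copy
import torch
import torch.nn as nn

original = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(4, 1, 3, padding=1))
clone = copy.deepcopy(original)  # stand-in for a network reloaded from file

# 1) Same parameter names and bit-identical tensors
sd_a, sd_b = original.state_dict(), clone.state_dict()
assert sd_a.keys() == sd_b.keys()
assert all(torch.equal(sd_a[name], sd_b[name]) for name in sd_a)

# 2) Same outputs on the same input; raises with a diagnostic message otherwise
x = torch.rand(2, 1, 32, 32)
with torch.no_grad():
    torch.testing.assert_close(original(x), clone(x), rtol=0.0, atol=1e-8)
print("copies agree")
```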