PyTorch Version - vai_p_pytorch
The pruning tool on PyTorch is a Python package rather than an executable program. Use the pruning APIs to prune the model.
Preparing a Baseline Model
For simplicity, ResNet18 from torchvision is used here. In real-life applications, the process of creating a model can be quite complicated.
from torchvision.models.resnet import resnet18
model = resnet18(pretrained=True)
Creating a Pruner
The Pruner class requires two arguments:
- The model to be pruned
- The inference inputs
import torch
from pytorch_nndct import Pruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)
For models with multiple inputs, you can use a list or a tuple of inputs to initialize a pruner.
Model Analysis
To run model analysis, an evaluation function must be passed to the pruner.ana() function. This evaluation function has one restriction: its first argument must be the model to be evaluated. An existing evaluation function usually does not meet this requirement, so you must define a wrapper function as shown below.
Suppose this is your evaluation function:
import time

import torch

# AverageMeter, ProgressMeter, and accuracy are helper utilities in the style of
# the PyTorch ImageNet example; AverageMeter is defined later in this section.
def evaluate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # move the model to the GPU once and switch to evaluate mode
    model = model.cuda()
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg, top5.avg
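The evaluate() function above relies on an accuracy() helper that is not defined in this section (the PyTorch ImageNet example provides a tensor-based one). As a rough illustration of what top-k accuracy measures, here is a minimal plain-Python sketch; the name topk_accuracy and the nested-list inputs are assumptions for illustration, not the helper that evaluate() actually calls.

```python
from typing import List, Sequence, Tuple

def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int],
                  topk: Tuple[int, ...] = (1, 5)) -> List[float]:
    """Return the top-k accuracy (in percent) for each k in `topk`.

    scores[i] holds the class scores of sample i; targets[i] is its true class.
    """
    results = []
    for k in topk:
        correct = 0
        for row, target in zip(scores, targets):
            # Indices of the k highest-scoring classes for this sample.
            top_classes = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            if target in top_classes:
                correct += 1
        results.append(100.0 * correct / len(targets))
    return results

# Three samples, four classes; only the first sample is right at top-1.
scores = [
    [0.1, 0.7, 0.1, 0.1],    # true class 1 has the highest score: top-1 hit
    [0.5, 0.3, 0.15, 0.05],  # true class 1 is second-highest: top-2 hit
    [0.4, 0.3, 0.2, 0.1],    # true class 3 has the lowest score: miss at top-2
]
targets = [1, 1, 3]
print(topk_accuracy(scores, targets, topk=(1, 2)))
```

A sample counts as correct at top-k when its true class is among its k highest scores, so Acc@5 is never lower than Acc@1.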
Define a wrapper to meet the evaluation function requirements:
def ana_eval_fn(model, val_loader, loss_fn):
    # Put the model first and return the top-5 accuracy as the score
    return evaluate(val_loader, model, loss_fn)[1]
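If several evaluation functions need the same treatment, the argument reordering can be factored into a small generic adapter. This is a minimal sketch: model_first and fake_evaluate are hypothetical names, not part of pytorch_nndct, and fake_evaluate only stands in for evaluate() so the argument order is visible.

```python
from typing import Any, Callable, Optional

def model_first(fn: Callable[..., Any], model_pos: int,
                result_index: Optional[int] = None) -> Callable[..., Any]:
    """Wrap `fn` so that the model becomes the first positional argument.

    `model_pos` is the position where `fn` originally expects the model.
    If `result_index` is given, only that element of fn's result is returned.
    """
    def wrapper(model, *args, **kwargs):
        call_args = list(args)
        call_args.insert(model_pos, model)  # put the model back where fn expects it
        result = fn(*call_args, **kwargs)
        return result if result_index is None else result[result_index]
    return wrapper

def fake_evaluate(val_loader, model, criterion):
    # Return (top1, top5) as labelled strings so the argument order is visible.
    return (f"top1({model}, {val_loader}, {criterion})",
            f"top5({model}, {val_loader}, {criterion})")

# Equivalent to the hand-written ana_eval_fn: model first, top-5 score returned.
ana_eval_fn = model_first(fake_evaluate, model_pos=1, result_index=1)
print(ana_eval_fn("net", "loader", "ce"))  # prints top5(net, loader, ce)
```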
Then, call the ana() method with the function defined above as the first argument.
pruner.ana(ana_eval_fn, args=(val_loader, criterion))
Here, ‘args’ is the tuple of arguments required by ‘ana_eval_fn’, starting from its second argument.
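Based on the description of ‘args’ above, the analysis step can be pictured as calling eval_fn with the model followed by the unpacked args tuple. The sketch below assumes exactly that convention; call_eval_fn is a hypothetical stand-in, not the actual internals of pruner.ana().

```python
from typing import Any, Callable, Tuple

def call_eval_fn(eval_fn: Callable[..., Any], model: Any,
                 args: Tuple[Any, ...] = ()) -> Any:
    """Hypothetical stand-in for how an analysis step might call eval_fn:
    the model comes first, followed by the elements of `args` in order."""
    # This sketch insists on a real tuple to catch the missing-comma mistake early.
    if not isinstance(args, tuple):
        raise TypeError("args must be a tuple, e.g. (val_loader,) for one argument")
    return eval_fn(model, *args)

def ana_eval_fn(model, val_loader, loss_fn):
    # Dummy "score" that simply shows which value landed in which parameter.
    return f"{model}|{val_loader}|{loss_fn}"

print(call_eval_fn(ana_eval_fn, "resnet18", ("val_loader", "criterion")))
# prints resnet18|val_loader|criterion
```

A one-element args tuple needs a trailing comma: (val_loader,) is a tuple, while (val_loader) is just val_loader.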
Pruning the Model
Call the prune() method to get a pruned model. The ratio is the proportion of FLOPs expected to be reduced, and output_script specifies the path of the Python script that can be used to rebuild the pruned model.
model = pruner.prune(ratio=0.1, output_script='pruned_resnet18.py')
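To see why the ratio is only approximate, consider a toy three-layer network in which only the output channels of the first convolution are pruned. Channels are removed in whole units, so the achievable FLOPs reduction moves in discrete steps and usually overshoots the target slightly. The layer shapes and helper names below are made up for illustration, counting multiply-accumulates as "FLOPs" is a simplifying assumption, and the real pruner distributes pruning across many layers.

```python
def conv_macs(h_out: int, w_out: int, c_in: int, c_out: int, k: int) -> int:
    """Multiply-accumulate count of a k x k convolution (groups=1, bias ignored)."""
    return h_out * w_out * c_in * c_out * k * k

def toy_net_macs(conv1_out: int) -> int:
    """Toy net: conv1 3->conv1_out, conv2 conv1_out->64, conv3 64->64.

    All layers are 3x3 with a 56x56 output. Pruning conv1's output channels
    shrinks conv1's C_out and conv2's C_in; conv3 is unaffected.
    """
    return (conv_macs(56, 56, 3, conv1_out, 3)
            + conv_macs(56, 56, conv1_out, 64, 3)
            + conv_macs(56, 56, 64, 64, 3))

def channels_to_remove(target_ratio: float, full: int = 64):
    """Smallest number of conv1 channels whose removal reaches target_ratio."""
    baseline = toy_net_macs(full)
    for removed in range(full):
        reduction = 1 - toy_net_macs(full - removed) / baseline
        if reduction >= target_ratio:
            return removed, reduction
    raise ValueError("target not reachable by pruning conv1 alone")

removed, achieved = channels_to_remove(0.1)
print(f"remove {removed} channels -> {achieved:.4f} FLOPs reduction")
# prints remove 13 channels -> 0.1039 FLOPs reduction
```

Removing 13 of conv1's 64 output channels (about 20% of them) yields only about a 10.4% reduction overall, because conv3's cost is untouched.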
Finetuning the Pruned Model
The process of fine-tuning is the same as training a baseline model. The difference is that the weights of the baseline model are randomly initialized and the weights of the pruned model are inherited from the baseline model.
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
import time

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # move the model to the GPU once and switch to train mode
    model = model.cuda()
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.cuda()
        target = target.cuda()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print('Epoch: [{}] Acc@1 {} Acc@5 {}'.format(epoch, top1.avg, top5.avg))
Next, run the training loop. Here, the ‘model’ parameter of the train() function is the object returned by the prune() method.
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)

best_acc5 = 0
epochs = 10
for epoch in range(epochs):
    train(train_loader, model, criterion, optimizer, epoch)
    acc1, acc5 = evaluate(val_loader, model, criterion)

    # remember the best Acc@5 and save checkpoints
    is_best = acc5 > best_acc5
    best_acc5 = max(acc5, best_acc5)
    if is_best:
        torch.save(model.state_dict(), 'resnet18_sparse.pth')
        torch.save(model.pruned_state_dict(), 'resnet18_dense.pth')
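The checkpointing logic in this loop saves only when Acc@5 strictly improves, and each save overwrites the same two files, so they always hold the best epoch seen so far. The following sketch replays that logic on made-up accuracy values; epochs_that_save is a hypothetical helper written only for illustration.

```python
from typing import List, Sequence

def epochs_that_save(acc5_per_epoch: Sequence[float]) -> List[int]:
    """Return the epochs at which the training loop would write checkpoints.

    Mirrors the is_best logic: save only when acc5 strictly beats the best so far.
    """
    best_acc5 = 0
    saved = []
    for epoch, acc5 in enumerate(acc5_per_epoch):
        is_best = acc5 > best_acc5
        best_acc5 = max(acc5, best_acc5)
        if is_best:
            saved.append(epoch)
    return saved

# Made-up Acc@5 values for five epochs.
print(epochs_that_save([85.1, 86.0, 85.7, 86.0, 87.2]))  # prints [0, 1, 4]
```

The tie at epoch 3 does not trigger a save, because the comparison uses > rather than >=.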
‘model.state_dict()’ returns sparse weights that have the same shapes as the baseline model, with the weights of pruned channels set to 0. ‘model.pruned_state_dict()’ returns dense weights with the pruned shapes. The first checkpoint is used as the input for the next round of pruning, and the second checkpoint is used for the final deployment. Always use the first (sparse) checkpoint between rounds of iterative pruning; after pruning is finished, use the second (dense) checkpoint.
Iterative Pruning
Load the sparse checkpoint and increase the pruning ratio. Here, the pruning ratio is increased from 0.1 to 0.2.
model = resnet18()
model.load_state_dict(torch.load('resnet18_sparse.pth'))
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)
model = pruner.prune(ratio=0.2, output_script='pruned_resnet18.py')
Once the new pruned model is generated, you can start fine-tuning again.
Using a Pruned Model
You can use the generated Python script to create a pruned model and then load the dense weights into it. The dense weights are obtained and saved with ‘pruned_state_dict()’ during fine-tuning.
from pruned_resnet18 import ResNet
model = ResNet()
model.load_state_dict(torch.load('resnet18_dense.pth'))
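The relationship between the sparse checkpoint (from state_dict()) and the dense one (from pruned_state_dict()) can be pictured with a single weight matrix of shape [out_channels, in_channels]. In this plain-Python sketch, nested lists stand in for tensors, and sparse_weights, dense_weights, and the kept channel indices are hypothetical. A real pruned model also drops the matching input channels of the following layers, which the sketch leaves out.

```python
from typing import List, Sequence

Matrix = List[List[float]]

def sparse_weights(weight: Matrix, kept: Sequence[int]) -> Matrix:
    """Keep the baseline shape; zero every output channel (row) not in `kept`."""
    keep = set(kept)
    return [list(row) if i in keep else [0.0] * len(row)
            for i, row in enumerate(weight)]

def dense_weights(weight: Matrix, kept: Sequence[int]) -> Matrix:
    """Keep only the listed output channels, so the first dimension shrinks."""
    return [list(weight[i]) for i in kept]

# 4 output channels x 2 input channels; channels 1 and 3 are pruned.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
kept = [0, 2]

sparse = sparse_weights(weight, kept)
dense = dense_weights(weight, kept)
print(sparse)  # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0], [0.0, 0.0]]
print(dense)   # [[1.0, 2.0], [5.0, 6.0]]
```

The sparse matrix still fits the baseline architecture, which is why the sparse checkpoint can be reloaded into resnet18() for another pruning round, while the dense one only fits the model rebuilt from the generated script.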
vai_p_pytorch APIs
pytorch_nndct.Pruner
Implements channel pruning at the module level.
Arguments
Pruner(module, inputs)
Create a new pruner object.
- module: A torch.nn.Module object to be pruned.
- inputs: The inputs of the module.
Methods
ana(eval_fn, args=(), gpus=None)
Performs model analysis.
- eval_fn: A callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
- args: A tuple of arguments that will be passed to eval_fn.
- gpus: A tuple or list of GPU indices used for model analysis. If not set, the default GPU is used.
prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')
Prunes the network by a given ratio or threshold and returns a torch.nn.Module object. Unlike a native torch module, the returned object has an additional method, ‘pruned_state_dict()’, which returns the parameters of the pruned dense model. The weights returned by ‘pruned_state_dict()’ can be loaded into the model created by the Python script saved to ‘output_script’.
- ratio: The expected percentage of FLOPs reduction. This is an approximation; the actual reduction after pruning may not exactly match this value.
- threshold: The relative proportion of model performance loss that can be tolerated.
- excludes: Modules to exclude from pruning.
- output_script: The file path where the generated script for rebuilding the model is saved.
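The threshold argument is described as a relative proportion of tolerable performance loss. One reading of that, shown below as an assumption rather than documented pytorch_nndct behavior, is that a pruned model is acceptable when its eval_fn score stays at or above baseline_score * (1 - threshold). The helper names and numbers are hypothetical.

```python
def min_acceptable_score(baseline_score: float, threshold: float) -> float:
    """Lowest score within a relative loss of `threshold` (0.05 means 5%).

    Assumed interpretation of `threshold`, not documented library behavior.
    """
    if not 0.0 <= threshold < 1.0:
        raise ValueError("threshold is expected to be a fraction in [0, 1)")
    return baseline_score * (1.0 - threshold)

def within_tolerance(baseline_score: float, pruned_score: float,
                     threshold: float) -> bool:
    """True if the pruned score loses at most `threshold` relative to baseline."""
    return pruned_score >= min_acceptable_score(baseline_score, threshold)

# Hypothetical scores: baseline Acc@5 of 90.0 with a 5% relative loss tolerated.
print(min_acceptable_score(90.0, 0.05))
print(within_tolerance(90.0, 86.0, 0.05), within_tolerance(90.0, 85.0, 0.05))
# second line prints True False
```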