PyTorch Version - vai_p_pytorch

For PyTorch, the pruning tool is a Python package rather than an executable program. You prune a model by calling the pruning APIs from your own script.

Preparing a Baseline Model

For simplicity, ResNet18 from torchvision is used here. In real-world applications, creating a model can be considerably more complicated.

from torchvision.models.resnet import resnet18
model = resnet18(pretrained=True)

Creating a Pruner

The Pruner class requires two arguments.

  • The model to be pruned
  • The inference inputs
Note: It is not necessary for the input to be real data. It can be randomly generated dummy data as long as it has the same shape and type as the real data.
import torch
from pytorch_nndct import Pruner

inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)

For models with multiple inputs, you can use a list or a tuple of inputs to initialize a pruner.

Model Analysis

To run model analysis, pass an evaluation function to the pruner.ana() method. The one requirement on this function is that its first argument must be the model to be evaluated. An existing evaluation function often takes the model in a different position; in that case, define a wrapper function as shown below.

Consider the following evaluation function:

import time

import torch

# AverageMeter, ProgressMeter, and accuracy() are the helpers defined in the
# PyTorch ImageNet training example; AverageMeter is also listed further below.
def evaluate(val_loader, model, criterion):
  batch_time = AverageMeter('Time', ':6.3f')
  losses = AverageMeter('Loss', ':.4e')
  top1 = AverageMeter('Acc@1', ':6.2f')
  top5 = AverageMeter('Acc@5', ':6.2f')
  progress = ProgressMeter(
      len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

  # switch to evaluate mode and move the model to the GPU once, outside the loop
  model.eval()
  model = model.cuda()

  with torch.no_grad():
    end = time.time()
    for i, (images, target) in enumerate(val_loader):
      images = images.cuda(non_blocking=True)
      target = target.cuda(non_blocking=True)

      # compute output
      output = model(images)
      loss = criterion(output, target)

      # measure accuracy and record loss
      acc1, acc5 = accuracy(output, target, topk=(1, 5))
      losses.update(loss.item(), images.size(0))
      top1.update(acc1[0], images.size(0))
      top5.update(acc5[0], images.size(0))

      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()

      if i % 50 == 0:
        progress.display(i)

    # print the final summary
    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))

  return top1.avg, top5.avg

Define a wrapper that takes the model as its first argument and returns a single score (here, the top-5 accuracy):

def ana_eval_fn(model, val_loader, loss_fn):
  return evaluate(val_loader, model, loss_fn)[1]

Then, call the ana() method, passing the wrapper function defined above as the first argument.

pruner.ana(ana_eval_fn, args=(val_loader, criterion))

Here, ‘args’ is the tuple of the remaining arguments required by ‘ana_eval_fn’, starting from its second argument.
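To make the calling convention concrete, the sketch below mimics how ana() is assumed to invoke the evaluation function: the model first, followed by the elements of ‘args’ in order. FakePruner and the toy evaluate() are hypothetical stand-ins, not part of pytorch_nndct; the "model" is just a callable, and only the argument ordering mirrors the text above.

```python
def evaluate(val_loader, model, criterion):
    """Toy stand-in with the same argument order as evaluate() above.

    `model` is any callable; the "accuracy" is the fraction of
    (sample, label) pairs it gets right. `criterion` is unused here.
    """
    correct = sum(1 for sample, label in val_loader if model(sample) == label)
    acc = correct / len(val_loader)
    return acc, acc  # (top-1, top-5) in the real function


def ana_eval_fn(model, val_loader, loss_fn):
    # The wrapper puts the model first and returns a single score.
    return evaluate(val_loader, model, loss_fn)[1]


class FakePruner:
    """Hypothetical stand-in for pytorch_nndct.Pruner; mimics only how eval_fn is called."""

    def __init__(self, model):
        self.model = model

    def ana(self, eval_fn, args=()):
        # Assumed calling convention: the model first, then the elements of args.
        return eval_fn(self.model, *args)


val_loader = [(0, 0), (1, 1), (2, 0), (3, 0)]
pruner = FakePruner(lambda x: x % 2)
score = pruner.ana(ana_eval_fn, args=(val_loader, None))
print(score)  # 0.75
```

Because ‘args’ is unpacked after the model, the tuple must list the wrapper's parameters in exactly the order they appear after ‘model’.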

Pruning the Model

Call the prune() method to get a pruned model. ratio is the expected proportion of FLOPs to be removed, and output_script specifies the path of the Python script that can be used to rebuild the pruned model.

model = pruner.prune(ratio=0.1, output_script='pruned_resnet18.py')
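As a rough illustration of what ratio measures, the sketch below counts multiply-accumulates (MACs) for two stacked convolutions and shows how removing output channels from the first layer also shrinks the input of the second. The layer sizes are made up, and this simplified counting rule is an assumption; the tool's own FLOPs accounting may differ (for example, in how it treats biases, batch norm, or the final fully connected layer). Since FLOPs are usually about twice the MACs, the relative reduction is the same either way.

```python
def conv_macs(c_in, c_out, k, h_out, w_out):
    """Multiply-accumulates of a k x k convolution (bias ignored)."""
    return c_in * c_out * k * k * h_out * w_out


def two_layer_macs(c_mid):
    # Hypothetical block: 64 -> c_mid -> 64 channels, 3x3 kernels, 56x56 outputs.
    first = conv_macs(64, c_mid, 3, 56, 56)
    second = conv_macs(c_mid, 64, 3, 56, 56)
    return first + second


baseline = two_layer_macs(64)
pruned = two_layer_macs(48)  # remove 16 of the 64 intermediate channels
reduction = 1 - pruned / baseline
print(f"FLOPs reduction: {reduction:.2%}")  # FLOPs reduction: 25.00%

# Channels can only be removed whole, so achievable reductions are discrete.
target = 0.1
options = {n: 1 - two_layer_macs(64 - n) / baseline for n in range(0, 64)}
closest = min(options, key=lambda n: abs(options[n] - target))
print(closest, f"{options[closest]:.3%}")  # 6 9.375%
```

This is one reason the requested ratio is only approximate: the reduction moves in channel-sized steps, so a target such as 0.1 may not be reachable exactly.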

Finetuning the Pruned Model

Fine-tuning follows the same process as training a baseline model. The only difference is that a baseline model starts from randomly initialized weights, whereas the pruned model inherits its weights from the baseline model.

class AverageMeter(object):
  """Computes and stores the average and current value"""

  def __init__(self, name, fmt=':f'):
    self.name = name
    self.fmt = fmt
    self.reset()

  def reset(self):
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

  def __str__(self):
    fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
    return fmtstr.format(**self.__dict__)

import time

# accuracy() is the top-k accuracy helper from the PyTorch ImageNet example,
# the same one used by evaluate() above.
def train(train_loader, model, criterion, optimizer, epoch):
  batch_time = AverageMeter('Time', ':6.3f')
  data_time = AverageMeter('Data', ':6.3f')
  losses = AverageMeter('Loss', ':.4e')
  top1 = AverageMeter('Acc@1', ':6.2f')
  top5 = AverageMeter('Acc@5', ':6.2f')

  # switch to train mode and move the model to the GPU once, outside the loop
  model.train()
  model = model.cuda()

  end = time.time()
  for i, (images, target) in enumerate(train_loader):
    # measure data loading time
    data_time.update(time.time() - end)

    images = images.cuda()
    target = target.cuda()

    # compute output
    output = model(images)
    loss = criterion(output, target)

    # measure accuracy and record loss
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    losses.update(loss.item(), images.size(0))
    top1.update(acc1[0], images.size(0))
    top5.update(acc5[0], images.size(0))

    # compute gradient and do an optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % 10 == 0:
      print('Epoch: [{}] Acc@1 {:.2f} Acc@5 {:.2f}'.format(epoch, top1.avg, top5.avg))

Next, run the training loop. Here, the model passed to train() is the object returned by the prune() method.

# move the pruned model to the GPU before creating the optimizer
model = model.cuda()

lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)

best_acc5 = 0
epochs = 10
for epoch in range(epochs):
  train(train_loader, model, criterion, optimizer, epoch)

  acc1, acc5 = evaluate(val_loader, model, criterion)

  # remember the best Acc@5 and save checkpoints
  is_best = acc5 > best_acc5
  best_acc5 = max(acc5, best_acc5)

  if is_best:
    # sparse weights: baseline shapes, pruned channels set to 0
    torch.save(model.state_dict(), 'resnet18_sparse.pth')
    # dense weights: pruned shapes, for the model rebuilt from output_script
    torch.save(model.pruned_state_dict(), 'resnet18_dense.pth')
Note: The two torch.save() calls at the end of the loop save two checkpoint files. ‘model.state_dict()’ returns sparse weights that keep the shapes of the baseline model, with the weights of pruned channels set to 0. ‘model.pruned_state_dict()’ returns dense weights with the pruned shapes. The sparse checkpoint is the input to the next round of pruning, and the dense checkpoint is used for final deployment: as long as further rounds of pruning remain, always continue from the sparse checkpoint; once pruning is complete, use the dense checkpoint.
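The difference between the two checkpoints can be pictured with a toy weight tensor. In the sketch below, a convolution weight is reduced to a plain list of per-output-channel rows, and to_sparse/to_dense are hypothetical helpers that imitate what the note describes; they are not pytorch_nndct functions.

```python
def to_sparse(weight, pruned):
    """Keep the baseline shape; zero every output channel listed in `pruned`."""
    return [[0.0] * len(row) if idx in pruned else list(row)
            for idx, row in enumerate(weight)]


def to_dense(weight, pruned):
    """Drop pruned output channels so the shape matches the pruned model."""
    return [list(row) for idx, row in enumerate(weight) if idx not in pruned]


# Hypothetical layer with 4 output channels and 2 weights per channel.
weight = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
pruned_channels = {1, 3}

sparse = to_sparse(weight, pruned_channels)
dense = to_dense(weight, pruned_channels)
print(sparse)  # [[0.1, 0.2], [0.0, 0.0], [0.5, 0.6], [0.0, 0.0]]
print(dense)   # [[0.1, 0.2], [0.5, 0.6]]
```

The sparse form still fits the baseline architecture, which is why it can be loaded into resnet18() for another round; the dense form only fits the architecture generated in ‘output_script’. In a real network, removing an output channel also removes the matching input channel of the next layer, which this sketch ignores.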

Iterative Pruning

Load the sparse checkpoint into the baseline architecture and increase the pruning ratio. Here, the ratio is increased from 0.1 to 0.2.

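import torch
from torchvision.models.resnet import resnet18
from pytorch_nndct import Pruner

# rebuild the baseline architecture and load the sparse checkpoint saved above
model = resnet18()
model.load_state_dict(torch.load('resnet18_sparse.pth'))

inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)
model = pruner.prune(ratio=0.2, output_script='pruned_resnet18.py')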

Once the new pruned model is generated, you can start fine-tuning again.
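The overall loop can be sketched as follows. prune_round and finetune are hypothetical callables standing in for the Pruner/prune() and fine-tuning steps above; the sketch captures only the bookkeeping: every round starts from the previous round's sparse checkpoint with a ratio that does not decrease, and only the final round's dense checkpoint is returned for deployment.

```python
def iterative_pruning(ratios, prune_round, finetune, baseline_ckpt):
    """Run several pruning rounds with non-decreasing ratios.

    prune_round(sparse_ckpt, ratio) -> pruned model (hypothetical callable)
    finetune(model, round_idx) -> (sparse_ckpt, dense_ckpt) (hypothetical callable)
    Returns the dense checkpoint of the last round, the one to deploy.
    """
    ratios = list(ratios)
    if ratios != sorted(ratios):
        raise ValueError("pruning ratios must not decrease between rounds")
    sparse_ckpt, dense_ckpt = baseline_ckpt, None
    for idx, ratio in enumerate(ratios):
        model = prune_round(sparse_ckpt, ratio)         # always start from sparse weights
        sparse_ckpt, dense_ckpt = finetune(model, idx)  # save both forms every round
    return dense_ckpt


# Stub callables that only record what they were given.
calls = []

def prune_round(ckpt, ratio):
    calls.append((ckpt, ratio))
    return f"model@{ratio}"

def finetune(model, idx):
    return f"sparse_{idx}.pth", f"dense_{idx}.pth"


final = iterative_pruning([0.1, 0.2, 0.3], prune_round, finetune, "baseline.pth")
print(final)  # dense_2.pth
print(calls)  # [('baseline.pth', 0.1), ('sparse_0.pth', 0.2), ('sparse_1.pth', 0.3)]
```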

Using a Pruned Model

You can use the generated Python script to create the pruned model and then load the dense weights into it. The dense weights are those returned by ‘pruned_state_dict()’ and saved during fine-tuning.

import torch
from pruned_resnet18 import ResNet

model = ResNet()
model.load_state_dict(torch.load('resnet18_dense.pth'))

vai_p_pytorch APIs

pytorch_nndct.Pruner

Implements channel pruning at the module level.

Arguments

Pruner(module, inputs)

Create a new pruner object.

module
A torch.nn.Module object to be pruned.
inputs
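Inputs to the module: a tensor, or a list or tuple of tensors for models with multiple inputs. Randomly generated data with the same shape and type as the real data is sufficient.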

Methods

  • ana(eval_fn, args=(), gpus=None)

    Performs model analysis.

    eval_fn
    Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
    args
    A tuple of additional arguments passed to eval_fn after the module.
    gpus
    A tuple or list of GPU indices used for model analysis. If not set, the default GPU will be used.
  • prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')

    Prunes the network by a given ratio or threshold and returns a ‘torch.nn.Module’ object. Unlike a native torch module, the returned object has one additional method, ‘pruned_state_dict()’, which returns the parameters of the pruned dense model. These weights can be loaded into the model created by the Python script saved in the ‘output_script’ file.

    ratio
    The expected proportion of FLOPs to remove (for example, 0.1 for 10%). This value is approximate; the actual reduction after pruning may not match it exactly.
    threshold
    Relative proportion of model performance loss that can be tolerated.
    excludes
    Modules to exclude from pruning.
    output_script
    Path of the file in which the generated script for rebuilding the pruned model is saved.
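To make threshold concrete, the sketch below assumes that "relative" means the loss is measured as a fraction of the baseline score; this interpretation is an assumption, not confirmed by the API description. min_acceptable_score and within_threshold are hypothetical helpers, and the scores are made up.

```python
def min_acceptable_score(baseline_score, threshold):
    """Lowest score that stays within a relative loss of `threshold` (assumed meaning)."""
    if not 0.0 <= threshold < 1.0:
        raise ValueError("threshold is expected to be a fraction in [0, 1)")
    return baseline_score * (1.0 - threshold)


def within_threshold(baseline_score, pruned_score, threshold):
    # Relative loss = drop in score divided by the baseline score.
    relative_loss = (baseline_score - pruned_score) / baseline_score
    return relative_loss <= threshold


print(f"{min_acceptable_score(0.90, 0.05):.3f}")  # 0.855
print(within_threshold(0.90, 0.86, 0.05))         # True  (about 4.4% relative loss)
print(within_threshold(0.90, 0.85, 0.05))         # False (about 5.6% relative loss)
```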