Advanced Usage¶
Here we will introduce some advanced usage of QSPARSE by topic. More information can be found in the API Reference.
%load_ext autoreload
%autoreload 2
import logging
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler()])
from qsparse import set_qsparse_options
set_qsparse_options(log_on_created=False)
Layerwise Pruning¶
The function devise_layerwise_pruning_schedule
traverses all pruning operators in the network, starting from the input, and assigns the step at which each operator activates, ensuring that each pruning operator is activated only after all its preceding layers have been pruned. The motivation and algorithm details can be found in our MDPI publication.
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from qsparse.sparse import prune, devise_layerwise_pruning_schedule
net = nn.Sequential(nn.Conv2d(3, 3, 3),
                    prune(sparsity=0.5),  # no need to specify `start, repetition, interval`
                    nn.Conv2d(3, 3, 3),
                    prune(sparsity=0.5))
devise_layerwise_pruning_schedule(net, start=1, interval=10) # notice the `start` of each prune layer increases
Pruning stops at iteration - 23
Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (1): PruneLayer(sparsity=0.5, start=1, interval=10, repetition=1, dimensions={1})
  (2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (3): PruneLayer(sparsity=0.5, start=12, interval=10, repetition=1, dimensions={1})
)
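The devised schedule can be reproduced by hand. Below is a back-of-envelope sketch (our own illustration, not a QSPARSE API) that assumes a layer started at step `s` finishes its last pruning repetition `interval * repetition` steps later, so the next layer can start one step after that:

```python
def layerwise_starts(num_prune_layers, start=1, interval=10, repetition=1):
    """Sketch of a layerwise schedule: each prune layer activates only
    after the previous one has finished all its pruning repetitions."""
    starts = []
    s = start
    for _ in range(num_prune_layers):
        starts.append(s)
        # a layer started at `s` finishes at s + interval * repetition,
        # so the next layer can start at the following step
        s = s + interval * repetition + 1
    return starts

print(layerwise_starts(2, start=1, interval=10))  # [1, 12]
```

These starting steps match the `start=1` and `start=12` shown in the devised schedule above.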
Network Conversion¶
The function convert
comes in handy for producing a pruned and quantized network instance without touching the existing floating-point network implementation. Here we introduce some frequent use cases.
1. Inserting pruning operator after all ReLU layers¶
from collections import OrderedDict
from qsparse import convert, quantize, prune
net = nn.Sequential(OrderedDict([
    ("first_half", nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU())),
    ("second_half", nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU()))]))
convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], inplace=False)
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .first_half.1 activation
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .second_half.1 activation
Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
)
2. Applying the quantization operator on the weight of all Conv2D layers¶
convert(net, quantize(bits=4), weight_layers=[nn.Conv2d], inplace=False)
Apply `quantize(bits=4, timeout=1000, callback=scalerquantizer, channelwise=1)` on the .first_half.0 weight
Apply `quantize(bits=4, timeout=1000, callback=scalerquantizer, channelwise=1)` on the .second_half.0 weight
Sequential(
  (first_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
  (second_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
)
3. Applying (1) and (2), but excluding the last ReLU and the first Conv2D layer¶
convert(convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU],
                excluded_activation_layer_indexes=[(nn.ReLU, [-1])], inplace=False),
        quantize(bits=4), weight_layers=[nn.Conv2d],
        excluded_weight_layer_indexes=[(nn.Conv2d, [0])], inplace=False)
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .first_half.1 activation
Exclude .second_half.1 activation
Exclude .first_half.0 weight
Apply `quantize(bits=4, timeout=1000, callback=scalerquantizer, channelwise=1)` on the .second_half.0 weight
Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(
      3, 3, kernel_size=(3, 3), stride=(1, 1)
      (quantize): QuantizeLayer(bits=4, timeout=1000, callback=ScalerQuantizer, channelwise=1)
    )
    (1): ReLU()
  )
)
4. Only inserting pruning in the first half of the network¶
convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], include=['first'], inplace=False)
# or convert(net, prune(sparsity=0.5), activation_layers=[nn.ReLU], exclude=['second'], inplace=False)
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .first_half.1 activation
Exclude .second_half.1 activation
Sequential(
  (first_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): Sequential(
      (0): ReLU()
      (1): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
    )
  )
  (second_half): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
)
5. Inserting pruning operator before all Conv2D layers¶
convert(net, prune(sparsity=0.5), activation_layers=[nn.Conv2d], order="pre", inplace=False)
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .first_half.0 activation
Apply `prune(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})` on the .second_half.0 activation
Sequential(
  (first_half): Sequential(
    (0): Sequential(
      (0): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ReLU()
  )
  (second_half): Sequential(
    (0): Sequential(
      (0): PruneLayer(sparsity=0.5, start=1000, interval=1000, repetition=4, dimensions={1})
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ReLU()
  )
)
More Quantization¶
Symmetric Quantization with Scaler¶
The class ScalerQuantizer
implements Algorithm 3 in our MDPI paper. Similarly, the class DecimalQuantizer
shares the same implementation, except that its scaling factor is always restricted to a power of 2. Their instances can be passed to the callback
argument of quantize
, like:
from qsparse.quantize import DecimalQuantizer
quantize(bits=8, callback=DecimalQuantizer())
QuantizeLayer(bits=8, timeout=1000, callback=DecimalQuantizer, channelwise=1)
The ScalerQuantizer
and DecimalQuantizer
classes include both inference and parameter learning. To access only the inference function for quantizing tensors, use the functions quantize_with_scaler
and quantize_with_decimal
:
import torch
from qsparse.quantize import quantize_with_decimal, quantize_with_scaler
data = torch.rand(1000)
((data - quantize_with_decimal(data, bits=8, decimal=6))**2).mean(), ((data - quantize_with_scaler(data, bits=8, scaler=0.01))**2).mean()
(tensor(8.2643e-05), tensor(8.4282e-06))
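For intuition, decimal quantization can be viewed as fixed-point quantization with a power-of-two scaling factor `2**-decimal`. The following numpy sketch shows this common formulation (an illustration of the idea, not QSPARSE's actual implementation):

```python
import numpy as np

def decimal_quantize(x, bits=8, decimal=6):
    # scale by 2**decimal, round to the nearest integer,
    # clamp to the signed `bits`-bit range, then scale back
    q = np.round(x * 2 ** decimal)
    q = np.clip(q, -(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
    return q / 2 ** decimal

# quantization step is 2**-6 = 0.015625, so values round to its multiples
print(decimal_quantize(np.array([0.1, 0.5, -0.3])))
```

Note that the step size here is fixed entirely by `decimal`, whereas `quantize_with_scaler` accepts an arbitrary scaling factor (0.01 above), which is why the two calls above report different errors.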
Asymmetric Quantization¶
The class AdaptiveQuantizer
implements Algorithm 2 in our MDPI paper, which estimates the lower and upper bounds of the incoming data stream and applies asymmetric quantization. Its inference function can be accessed through quantize_with_line
.
from qsparse.quantize import AdaptiveQuantizer
quantize(bits=8, callback=AdaptiveQuantizer())
QuantizeLayer(bits=8, timeout=1000, callback=AdaptiveQuantizer, channelwise=1)
from qsparse.quantize import quantize_with_line
((data - quantize_with_line(data, bits=8, lines=(0, 1)))**2).mean() # lines specify the (lower, upper) bounds.
tensor(1.2650e-06)
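The asymmetric scheme amounts to affine quantization between the two bounds. Below is a minimal numpy sketch of this standard formulation (our own illustration, not the library's exact code):

```python
import numpy as np

def line_quantize(x, bits=8, lower=0.0, upper=1.0):
    # affine quantization: map [lower, upper] onto 2**bits evenly
    # spaced levels, round to the nearest level, then map back
    step = (upper - lower) / (2 ** bits - 1)
    q = np.clip(np.round((x - lower) / step), 0, 2 ** bits - 1)
    return q * step + lower

data = np.random.default_rng(0).random(100_000)
mse = ((data - line_quantize(data)) ** 2).mean()
print(mse)  # close to step**2 / 12, i.e. roughly 1.3e-6 for 8 bits on [0, 1]
```

The measured error is of the same order of magnitude as the value reported by `quantize_with_line` above.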
Channelwise Quantization¶
Channel-wise quantization denotes the technique of using different decimal bits across channels, i.e., quantizing each channel independently. It is well known that channel-wise quantization can drastically reduce quantization error, especially when numerical ranges vary widely between channels.
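The effect is easy to reproduce with a toy experiment, sketched below in plain numpy (independent of QSPARSE): two channels whose ranges differ by 100x are quantized with one shared scaling factor versus one scaling factor per channel.

```python
import numpy as np

rng = np.random.default_rng(0)
# two channels with very different numerical ranges
w = np.stack([rng.uniform(-1, 1, 1000), rng.uniform(-0.01, 0.01, 1000)])

def sym_quantize(x, bits=8):
    # symmetric quantization with a single scaling factor for all of `x`
    scale = np.abs(x).max() / (2 ** (bits - 1) - 1)
    return np.round(x / scale) * scale

err_tensor = ((w - sym_quantize(w)) ** 2).mean()
err_channel = ((w - np.stack([sym_quantize(c) for c in w])) ** 2).mean()
print(err_channel < err_tensor)  # True: the small-range channel benefits greatly
```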
To enable channelwise quantization along dimension 1 (i.e., treating dimension 1 as the channel dimension):
quantize(bits=8, channelwise=1)
QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=1)
To disable channelwise quantization:
quantize(bits=8, channelwise=-1)
QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=-1)
Groupwise Quantization¶
Channelwise quantization allocates one set of scaling factors and zero-points per channel, which can complicate the inference implementation when both weights and activations are quantized channel-wise, especially for networks with a large number of channels. Here, we provide a technique we name groupwise quantization: we cluster the channel-wise quantization parameters (scaling factors and zero-points) into groups, and share one set of quantization parameters within each group. We empirically find that groupwise quantization yields little to no performance drop compared to channelwise quantization, even with an extremely small number of groups, e.g. 4.
# `group_timeout` denotes the number of steps after the quantization
# operator activates before the clustering starts
layer = quantize(bits=8, channelwise=1, timeout=10,
                 callback=DecimalQuantizer(group_num=4, group_timeout=10))
for _ in range(21):
    layer(torch.rand(1, 1024, 3, 3))
quantizing with 8 bits
clustering 1024 channels into 4 groups
For a convolution layer with 1024 channels, groupwise quantization with 4 groups yields a 256x reduction in the number of quantization parameters.
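The grouping step can be illustrated with a simplified sketch (our own approximation: instead of the clustering used by `DecimalQuantizer`, we split the sorted per-channel scaling factors into equal-sized bins and share the mean within each bin):

```python
import numpy as np

rng = np.random.default_rng(0)
channels, group_num = 1024, 4
# synthetic per-channel scaling factors (e.g. learned by channelwise quantization)
scales = rng.lognormal(mean=-3.0, sigma=1.0, size=channels)

# sort channels by scale, split them into `group_num` equal bins,
# and let every channel in a bin share that bin's mean scale
order = np.argsort(scales)
shared = np.empty_like(scales)
for group in np.array_split(order, group_num):
    shared[group] = scales[group].mean()

print(len(np.unique(shared)))  # 4 distinct scaling factors remain
print(channels // group_num)   # 256x fewer quantization parameters than channelwise
```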
Quantization Bias¶
By default, for weight quantization, quantize only quantizes the weight parameter and leaves the bias parameter in full precision (Jacob et al.), because the bias can be used to initialize the high-precision accumulator for the multiply-add operations. The bias can be quantized in QSPARSE by:
from qsparse import quantize
quantize(nn.Conv2d(1, 1, 1), bits=8, bias_bits=12)
Conv2d(
  1, 1, kernel_size=(1, 1), stride=(1, 1)
  (quantize): QuantizeLayer(bits=8, timeout=1000, callback=ScalerQuantizer, channelwise=1)
  (quantize_bias): QuantizeLayer(bits=12, timeout=1000, callback=ScalerQuantizer, channelwise=0)
)
Integer Arithmetic Verification¶
Here we provide an example to demonstrate that floating-point simulated quantization fully matches 8-bit integer arithmetic.
ni = 7 # input shift
no = 6 # output shift
input = torch.randint(-128, 127, size=(3, 10, 32, 32))
input_float = input.float() / 2 ** ni
Quantization computation simulated with floating-point:
timeout = 5
qconv = quantize(
    torch.nn.Conv2d(10, 30, 3, bias=False),
    bits=8, timeout=timeout, channelwise=0, callback=DecimalQuantizer()
)
qconv.train()
for _ in range(timeout + 1):  # ensure the quantization has been triggered
    qconv(input_float)
output_float = quantize_with_decimal(qconv(input_float), 8, no)
quantizing with 8 bits
Reproduce the above computation in 8-bit arithmetic:
decimal = (1 / qconv.quantize.weight).nan_to_num(posinf=1, neginf=1).log2().round().int()
weight = qconv.weight * (2.0 ** decimal).view(-1, 1, 1, 1)
output_int = F.conv2d(input.int(), weight.int())
for i in range(output_int.shape[1]):
    output_int[:, i] = (
        output_int[:, i].float() / 2 ** (ni + decimal[i] - no)
    ).int()
diff = (
    output_float.detach().numpy() - (output_int.float() / 2 ** no).detach().numpy()
)
assert np.all(diff == 0)
print("Fully match with integer arithmetic")
Fully match with integer arithmetic
Extras¶
Resuming from Checkpoint¶
Both quantize
and prune
layers support resuming training from a checkpoint. However:

- QSPARSE determines the shape of its parameters (e.g. scaling factor, mask) at the first forward pass.
- load_state_dict currently does not allow shape mismatch (pytorch/issues#40859).

Therefore, we provide preload_qsparse_state_dict
, which should be called before load_state_dict
to mitigate the above issue.
from qsparse.util import preload_qsparse_state_dict
def make_conv():
    return quantize(prune(nn.Conv2d(16, 32, 3),
                          sparsity=0.5, start=200,
                          interval=10, repetition=4),
                    bits=8, timeout=100)
conv = make_conv()
for _ in range(241):
    conv(torch.rand(10, 16, 7, 7))
try:
    conv2 = make_conv()
    conv2.load_state_dict(conv.state_dict())
except RuntimeError as e:
    print(f'\nCatch error as expected: {e}\n')
conv3 = make_conv()
preload_qsparse_state_dict(conv3, conv.state_dict())
conv3.load_state_dict(conv.state_dict())
tensor = torch.rand(10, 16, 7, 7)
assert np.allclose(conv(tensor).detach().numpy(), conv3(tensor).detach().numpy(), atol=1e-6)
print('successfully loading from checkpoint')
quantizing with 8 bits
[Prune] [Step 200] pruned 0.29
Start pruning at @ 200
[Prune] [Step 210] pruned 0.44
[Prune] [Step 220] pruned 0.49
[Prune] [Step 230] pruned 0.50

Catch error as expected: Error(s) in loading state_dict for Conv2d:
	Unexpected key(s) in state_dict: "prune.callback.magnitude", "quantize.weight", "quantize._n_updates".
	size mismatch for prune.mask: copying a param with shape torch.Size([1, 16, 1, 1]) from checkpoint, the shape in current model is torch.Size([]).

successfully loading from checkpoint
Inspecting Parameters of a Pruned/Quantized Model¶
Parameters of a quantized and pruned network can be easily inspected, and therefore post-processed for use cases such as compiling for neural engines:
state_dict = conv.state_dict()
for k, v in state_dict.items():
    print(k, v.numpy().shape)
weight (32, 16, 3, 3)
bias (32,)
prune.mask (1, 16, 1, 1)
prune._n_updates (1,)
prune._cur_sparsity (1,)
prune.callback.t (1,)
prune.callback.magnitude (1, 16, 1, 1)
quantize.weight (16, 1)
quantize._n_updates (1,)
Param | Description
---|---
quantize.weight | scaling factors
*._n_updates | internal counter for number of training steps
prune.mask | binary mask for pruning
prune._cur_sparsity | internal variable to record current sparsity
prune.callback.magnitude | internal boolean variable to record whether quantization has been triggered.