
fuse

BNFuser (Protocol)

Type signature of the handlers used in fuse_bn.

__call__(self, layer, bn) special

Fuse batch norm into the previous layer.

Parameters:

    layer (nn.Module): layers like Conv2d, Linear, etc. Required.
    bn (nn.Module): batch norm layer, could be BatchNorm2d, BatchNorm1d, etc. Required.

Returns:

    nn.Module: fused layer

Source code in qsparse/fuse.py
def __call__(self, layer: nn.Module, bn: nn.Module) -> nn.Module:
    """Fuse batch norm into the previous layer.

    Args:
        layer (nn.Module): layers like Conv2d, Linear, etc.
        bn (nn.Module): batch norm layer, could be BatchNorm2d, BatchNorm1d, etc.

    Returns:
        nn.Module: fused layer
    """

conv2d_bn_fuser(conv, bn)

BNFuser for Conv2d

Source code in qsparse/fuse.py
def conv2d_bn_fuser(conv: nn.Module, bn: nn.Module) -> nn.Module:
    """BNFuser for Conv2d"""
    w = conv._parameters["weight"].detach()
    b = conv._parameters["bias"].detach() if conv.bias is not None else 0
    mean = bn.running_mean.detach()
    var_sqrt = torch.sqrt(bn.running_var.detach().add(1e-5))
    gamma = bn.weight.detach()
    beta = bn.bias.detach()
    new_weight = w * (gamma / var_sqrt)[:, None, None, None]
    new_bias = (b - mean) * gamma / var_sqrt + beta
    conv._parameters["weight"].data = new_weight
    conv._parameters["bias"] = nn.Parameter(new_bias)
    return conv
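As a sanity check (not from the qsparse docs; shapes are arbitrary), the fused Conv2d reproduces the eval-mode output of the original Conv2d followed by BatchNorm2d, since the fusion uses the batch norm's running statistics and the default eps of 1e-5.

import torch
import torch.nn as nn
from qsparse.fuse import conv2d_bn_fuser

conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
with torch.no_grad():                  # populate non-trivial running statistics
    for _ in range(5):
        bn(conv(torch.randn(4, 3, 16, 16)))
conv.eval(); bn.eval()

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))            # original two-layer computation
    fused = conv2d_bn_fuser(conv, bn)  # mutates conv in place
    assert torch.allclose(fused(x), reference, atol=1e-5)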

deconv2d_bn_fuser(deconv, bn)

BNFuser for ConvTranspose2d

Source code in qsparse/fuse.py
def deconv2d_bn_fuser(deconv: nn.Module, bn: nn.Module) -> nn.Module:
    """BNFuser for ConvTranspose2d"""
    w = deconv._parameters["weight"].detach()
    b = deconv._parameters["bias"].detach() if deconv.bias is not None else 0
    mean = bn.running_mean.detach()
    var_sqrt = torch.sqrt(bn.running_var.detach().add(1e-5))
    gamma = bn.weight.detach()
    beta = bn.bias.detach()
    new_weight = w * (gamma / var_sqrt)[None, :, None, None]
    new_bias = (b - mean) * gamma / var_sqrt + beta
    deconv._parameters["weight"].data = new_weight
    deconv._parameters["bias"] = nn.Parameter(new_bias)
    return deconv
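The only difference from conv2d_bn_fuser is the broadcast axis: ConvTranspose2d stores its weight as (in_channels, out_channels, kH, kW), so the per-output-channel scale gamma / var_sqrt is applied along dimension 1. A minimal eval-mode check (shapes chosen arbitrarily, not from the qsparse docs):

import torch
import torch.nn as nn
from qsparse.fuse import deconv2d_bn_fuser

deconv, bn = nn.ConvTranspose2d(8, 4, 2, stride=2).eval(), nn.BatchNorm2d(4).eval()
x = torch.randn(1, 8, 7, 7)
with torch.no_grad():
    reference = bn(deconv(x))
    assert torch.allclose(deconv2d_bn_fuser(deconv, bn)(x), reference, atol=1e-5)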

fuse_bn(model, layers=['Conv2d', 'Linear', 'ConvTranspose2d'], handlers=None, log=True, inplace=True)

Fuse the batch norm layers into the preceding conv/deconv/linear layers in a network.

Parameters:

    model (nn.Module): network. Required.
    layers (Iterable[str], optional): layer types to fuse batch norm into. Defaults to ["Conv2d", "Linear", "ConvTranspose2d"].
    handlers (Optional[Mapping[str, BNFuser]], optional): mapping from layer type to BNFuser. Defaults to None, which uses { Linear: linear_bn_fuser, Conv2d: conv2d_bn_fuser, ConvTranspose2d: deconv2d_bn_fuser }.
    log (bool, optional): whether to print the fuse log. Defaults to True.
    inplace (bool, optional): whether to mutate the original module. Defaults to True.

Returns:

    nn.Module: network with batch norm fused

Source code in qsparse/fuse.py
def fuse_bn(  # noqa: C901
    model: nn.Module,
    layers: Iterable[str] = ["Conv2d", "Linear", "ConvTranspose2d"],
    handlers: Optional[Mapping[str, BNFuser]] = None,
    log: bool = True,
    inplace: bool = True,
) -> nn.Module:
    """Fuse the batch norm layers back to the previous conv/deconv/linear layers in a newtwork.

    Args:
        model (nn.Module): network
        layers (Iterable[str], optional): [description]. Defaults to ["Conv2d", "Linear", "ConvTranspose2d"].
        handlers (Optional[Mapping[str, BNFuser]], optional): Mapping from layer type to [BNFuser][qsparse.fuse.BNFuser]. Defaults to None, will use { Linear: [linear\_bn\_fuser][qsparse.fuse.linear_bn_fuser], Conv2d: [conv2d\_bn\_fuser][qsparse.fuse.conv2d_bn_fuser], ConvTranspose2d: [deconv2d\_bn\_fuser][qsparse.fuse.deconv2d_bn_fuser] }.
        log (bool, optional): whether print the fuse log. Defaults to True.
        inplace (bool, optional): whether mutates the original module. Defaults to False.

    Returns:
        nn.Module: network with bn fused
    """
    handlers = {**copy(default_handlers), **(handlers or {})}
    layers = set(layers)
    for name in layers:
        assert name in handlers, f"layer {name} is not in handlers"

    if not inplace:
        model = deepcopy(model)

    def is_bn(layer: nn.Module) -> bool:
        return layer.__class__.__name__.lower().startswith("batchnorm")

    def get_layer_type(layer: Optional[nn.Module]) -> str:
        if layer is None:
            return ""
        else:
            return layer.__class__.__name__

    def fuse_bn_sequential(
        seq: nn.Sequential, input: Optional[nn.Module] = None
    ) -> Tuple[nn.Module, Optional[nn.Module]]:
        sequence = []

        def get_prev_layer():
            return sequence[-1] if len(sequence) > 0 else input

        for layer in seq.children():
            if is_bn(layer):
                bn = layer
                operation = get_prev_layer()
                layer_type = get_layer_type(operation)
                if layer_type in layers:
                    if log:
                        logging.info(f"Fuse {bn} into {operation}")
                    operation = handlers[layer_type](operation, bn)
                    if len(sequence) > 0:
                        sequence[-1] = operation
                    else:
                        input = operation
                else:
                    sequence.append(bn)
            elif isinstance(layer, nn.Sequential):
                layer, prev_layer = fuse_bn_sequential(layer, get_prev_layer())
                if prev_layer is not None:
                    if len(sequence) > 0:
                        sequence[-1] = prev_layer
                    else:
                        input = prev_layer
                if layer is not None:
                    sequence.append(layer)
            else:
                sequence.append(layer)
        if len(sequence) == 0:
            return None, input
        elif len(sequence) == 1:
            return sequence[0], input
        else:
            return nn.Sequential(*sequence), input

    if isinstance(nn_module(model), nn.Sequential):
        _model = fuse_bn_sequential(nn_module(model))[0]
        if model == nn_module(model):
            model = _model
        else:
            model.module = _model
    else:
        for name, m in nn_module(model).named_children():
            if isinstance(m, nn.Sequential):
                nn_module(model)._modules[name] = fuse_bn_sequential(m)[0]
    return model
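A minimal usage sketch (the network below is made up for illustration): fusing in eval mode leaves the forward computation unchanged while removing the BatchNorm2d modules, and inplace=False keeps the original network untouched.

import torch
import torch.nn as nn
from qsparse.fuse import fuse_bn

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
).eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    reference = net(x)
    fused = fuse_bn(net, inplace=False)  # works on a deep copy of net
    assert torch.allclose(fused(x), reference, atol=1e-5)
print(fused)  # the BatchNorm2d layers are gone; their statistics live in the conv weights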

linear_bn_fuser(linear, bn)

BNFuser for Linear

Source code in qsparse/fuse.py
def linear_bn_fuser(linear: nn.Module, bn: nn.Module) -> nn.Module:
    """BNFuser for Linear"""
    w = linear._parameters["weight"].detach()
    b = linear._parameters["bias"].detach() if linear.bias is not None else 0
    mean = bn.running_mean.detach()
    var_sqrt = torch.sqrt(bn.running_var.detach().add(1e-5))
    gamma = bn.weight.detach()
    beta = bn.bias.detach()
    new_weight = w * (gamma / var_sqrt)[:, None]
    new_bias = (b - mean) * gamma / var_sqrt + beta
    linear._parameters["weight"].data = new_weight
    linear._parameters["bias"] = nn.Parameter(new_bias)
    return linear
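The same eval-mode equivalence holds for Linear followed by BatchNorm1d (a small illustrative check, not from the qsparse docs):

import torch
import torch.nn as nn
from qsparse.fuse import linear_bn_fuser

linear, bn = nn.Linear(32, 16).eval(), nn.BatchNorm1d(16).eval()
x = torch.randn(4, 32)
with torch.no_grad():
    reference = bn(linear(x))
    assert torch.allclose(linear_bn_fuser(linear, bn)(x), reference, atol=1e-5)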