sparse

MagnitudePruningCallback (Module)

__init__(self, mask_refresh_interval=-1, stop_mask_refresh=inf, use_gradient=False, running_average=True, l0=False, forward_hook=None) special

Magnitude-based pruning function as the callback of prune.

Parameters:

    mask_refresh_interval (int, optional): number of steps between mask refreshes. Defaults to -1, which is reset to 1 when training starts.
    stop_mask_refresh (int, optional): step after which the mask is no longer refreshed. Defaults to float('inf').
    use_gradient (bool, optional): whether to use the magnitude of gradients. Defaults to False.
    running_average (bool, optional): whether to use a running average of the magnitude. Defaults to True.
    l0 (bool, optional): whether to use the l0 magnitude instead of l1. Defaults to False.
    forward_hook (Callable, optional): callback executed at each forward pass. Defaults to None.
Source code in qsparse/sparse.py
def __init__(
    self,
    mask_refresh_interval: int = -1,
    stop_mask_refresh: int = float("inf"),
    use_gradient: bool = False,
    running_average: bool = True,
    l0: bool = False,
    forward_hook: Callable[[torch.Tensor, str], None] = None
):
    """
    Magnitude-based pruning function  as the callback of [prune][qsparse.sparse.prune].

    Args:
        mask_refresh_interval (int, optional): number of steps to refresh mask. Defaults to 1.
        stop_mask_refresh (int, optional): when to stop refreshing mask. Defaults to float('inf').
        use_gradient (bool, optional): whether use the magnitude of gradients
        running_average (bool, optional): whether use the running average of magnitude. Defaults to True.
        l0 (bool, optional): whether to use l0 magnitude instead of l0
        forward_hook (Callable, optional): callback function that gets executed at each forward. Defaults to None.
    """
    super().__init__()
    self.mask_refresh_interval = mask_refresh_interval
    self.stop_mask_refresh = stop_mask_refresh
    self.use_gradient = use_gradient
    self.t = nn.Parameter(torch.full((1,), -1), requires_grad=False)
    if use_gradient and not running_average:
        raise ValueError(
            "the combination of `use_gradient=True` and `running_average=False` is not supported"
        )
    self.running_average = running_average
    self.prev_grad_hook = None
    self.l0 = l0
    self.forward_hook = forward_hook

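Example (a minimal sketch; assumes `prune` is exported at the package root as `from qsparse import prune`, otherwise import it from `qsparse.sparse`):

import torch.nn as nn
from qsparse import prune
from qsparse.sparse import MagnitudePruningCallback

# wrap a conv layer so that its weight is pruned to 50% sparsity during
# training, refreshing the mask every 1000 steps
conv = prune(
    nn.Conv2d(16, 32, 3),
    sparsity=0.5,
    callback=MagnitudePruningCallback(mask_refresh_interval=1000),
)
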
forward(self, x, sparsity, mask, name='')

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in qsparse/sparse.py
def forward(self, x: torch.Tensor, sparsity: float, mask: torch.Tensor, name=""):
    if self.training:
        if not self.initted:
            self.initialize(mask)
            self.t.data[:] = 0
            # a non-positive interval falls back to refreshing every step
            if self.mask_refresh_interval <= 0:
                self.mask_refresh_interval = 1

        t_item = self.t.item()
        # accumulate magnitude statistics until mask refreshing stops
        if t_item < self.stop_mask_refresh:
            self.receive_input(x)
        # refresh the mask on schedule; when using a running average, skip
        # step 0 because no statistics have been accumulated yet
        if (
            sparsity >= 0
            and t_item % self.mask_refresh_interval == 0
            and t_item <= self.stop_mask_refresh
            and (t_item > 0 or not self.running_average)
        ):
            out = self.prune_and_update_mask(x, sparsity, mask)
        else:
            out = x * mask
        self.t += 1
        if self.forward_hook is not None:
            self.forward_hook(mask, name)
        return out
    else:
        return x * mask

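Example of a `forward_hook` (grounded in the call `self.forward_hook(mask, name)` above; the hook name is hypothetical):

import torch
from qsparse.sparse import MagnitudePruningCallback

def log_mask_sparsity(mask: torch.Tensor, name: str) -> None:
    # fraction of zeros in the binary mask
    print(f"{name or 'prune'}: sparsity = {1.0 - mask.float().mean().item():.2f}")

callback = MagnitudePruningCallback(forward_hook=log_mask_sparsity)
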
PruneLayer (Module)

Applies pruning over the input tensor. Please refer to prune for a detailed description.

initted: bool property readonly

whether the parameters of the prune layer are initialized.

forward(self, x)

Prunes input tensor according to given sparsification schedule.

Parameters:

    x (torch.Tensor): tensor to be pruned (required)

Exceptions:

    RuntimeError: when the shape of the input tensor mismatches the shape of the binary mask

Returns:

    torch.Tensor: pruned tensor

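For intuition, the mask keeps its full size only along the pruned dimensions and broadcasts along the rest, mirroring the `mask_shape` computation in the source below; e.g., with the default dimensions={1}, a (N, C, H, W) input gets a (1, C, 1, 1) mask:

x_shape = (8, 32, 16, 16)   # (N, C, H, W) activation
dimensions = {1}            # prune along the channel dimension
mask_shape = [1 if i not in dimensions else s for i, s in enumerate(x_shape)]
assert mask_shape == [1, 32, 1, 1]
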
Source code in qsparse/sparse.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Prunes input tensor according to given sparsification schedule.

    Args:
        x (torch.Tensor): tensor to be pruned

    Raises:
        RuntimeError: when the shape of input tensors mismatch with the shape of binary mask

    Returns:
        torch.Tensor: pruned tensor
    """

    if not self.initted:
        assert len(x.shape) > 1
        with torch.no_grad():
            mask_shape = [1 if i not in self.dimensions else s
                          for i, s in enumerate(list(x.shape))]
            self.mask = nn.Parameter(
                torch.ones(
                    *mask_shape,
                    dtype=torch.bool,
                ).to(x.device),
                requires_grad=False,
            )
            if self.mask.numel() == 1:
                logging.warning(f"the mask shape of {self.name} is {tuple(self.mask.shape)}, which is not prunable")

        self._n_updates = nn.Parameter(
            torch.zeros(1, dtype=torch.int).to(x.device),
            requires_grad=False,
        )
        self._cur_sparsity = nn.Parameter(
            torch.zeros(1).to(x.device), requires_grad=False
        )

    if (self._n_updates.item() in self.schedules) and self.training:
        # cubic schedule: sparsity grows as target * (1 - (1 - progress)^3)
        ratio = (
            1.0
            - (self._n_updates.item() - self.start + self.rampup_interval)
            / (self.interval * self.repetition)
        ) ** 3
        self._cur_sparsity[0] = self.sparsity * (1 - ratio)
        logging.warning(
            f"[Prune{self.name if self.name == '' else f' @ {self.name}'}] [Step {self._n_updates.item()}] pruned {self._cur_sparsity.item():.02f}"
        )

    if not self.training or self.mask.numel() == 1:
        out = x * self.mask
    else:
        n_updates = self._n_updates.item()
        if n_updates >= self.start:
            if n_updates == self.start:
                logging.warning(f"Start pruning at {self.name} @ {n_updates}")
            out = self.callback(x, self._cur_sparsity.item(), mask=self.mask, name=self.name)
        else:
            out = x
        self._n_updates += 1
    return out

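Example of feature pruning with a standalone PruneLayer (a sketch with hypothetical step counts; the layer must run in training mode so its step counter advances):

import torch
from qsparse import prune

layer = prune(sparsity=0.5, start=100, interval=100, repetition=4, name="feat")
layer.train()
for step in range(600):
    x = torch.randn(8, 32, 16, 16)
    y = layer(x)  # identity before step 100, progressively pruned afterwards
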
UniformPruningCallback (MagnitudePruningCallback)

Unstructured uniform pruning function.

This function prunes uniformly at random, without considering the magnitude of the input tensors. If an initial mask is provided, it will not reactivate locations that were already pruned in the initial mask.

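Example (same `prune` interface as above; a sketch assuming `UniformPruningCallback` is importable from `qsparse.sparse`):

from qsparse import prune
from qsparse.sparse import UniformPruningCallback

layer = prune(sparsity=0.25, callback=UniformPruningCallback(), name="uniform_feat")
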
prune(inp=None, sparsity=0.5, dimensions={1}, callback=None, start=1000, interval=1000, repetition=4, rampup=False, name='')

Creates a PruneLayer which is usually used for feature pruning if no input module is provided, or creates a weight-pruned version of the input module.

Parameters:

    inp (nn.Module, optional): input module whose weight is to be pruned. Defaults to None.
    sparsity (float, optional): target sparsity. Defaults to 0.5.
    dimensions (Iterable[int], optional): which dimensions to prune. Defaults to {1}, pruning the channel dimension of a conv feature map.
    callback (MagnitudePruningCallback, optional): callback that computes the binary mask and prunes inputs. Defaults to MagnitudePruningCallback.
    start (int, optional): step at which pruning starts. Defaults to 1000.
    interval (int, optional): number of iterations between sparsity-increasing steps. Defaults to 1000.
    repetition (int, optional): number of sparsity-increasing steps. Defaults to 4.
    rampup (bool, optional): whether to wait another interval before starting to prune. Defaults to False.
    name (str, optional): name of the created prune layer, used for better logging. Defaults to "".

Returns:

    nn.Module: the input module with its weight pruned, or an instance of PruneLayer for feature pruning

Source code in qsparse/sparse.py
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def prune(
    inp: nn.Module = None,
    sparsity: float = 0.5,
    dimensions: Iterable[int] = {1},
    callback: MagnitudePruningCallback = None,
    # step-wise pruning parameters
    start: int = 1000,
    interval: int = 1000,
    repetition: int = 4,
    rampup: bool = False,
    name=""
) -> nn.Module:
    """Creates a [PruneLayer][qsparse.sparse.PruneLayer] which is usually used
    for feature pruning if no input module is provided, or creates a weight-
    pruned version of the input module.

    Args:
        inp (nn.Module, optional): input module whose weight is to be pruned. Defaults to None.
        sparsity (float, optional): target sparsity. Defaults to 0.5.
        dimensions (Iterable[int]): which dimensions to prune. Defaults to {1}, pruning the channel dimension of conv feature map.
        callback (MagnitudePruningCallback, optional): callback for actual operation of calculating binary mask and prune inputs. Defaults to [MagnitudePruningCallback][qsparse.sparse.MagnitudePruningCallback].
        start (int, optional): starting step to apply pruning. Defaults to 1000.
        interval (int, optional): interval of iterations between each sparsity increasing steps. Defaults to 1000.
        repetition (int, optional): number of sparsity increasing steps. Defaults to 4.
        rampup (bool, optional): whether to wait another interval before starting to prune. Defaults to False.
        name (str, optional): name of the prune layer created, used for better logging. Defaults to "".

    Returns:
        nn.Module: input module with its weight pruned or an instance of [PruneLayer][qsparse.sparse.PruneLayer] for feature pruning
    """
    callback = callback or MagnitudePruningCallback()

    kwargs = dict(
        start=int(start),
        sparsity=sparsity,
        interval=int(interval),
        repetition=repetition,
        rampup=rampup,
        name=name,
        callback=callback,
        dimensions=dimensions,
    )

    def get_prune_layer():
        return PruneLayer(**kwargs)

    if inp is None:
        layer = get_prune_layer()
        setattr(layer, "_kwargs", kwargs)
        return layer
    elif isinstance(inp, nn.Module):
        return imitate(inp, "prune", get_prune_layer())
    else:
        raise ValueError(f"{inp} is not a valid argument for prune")
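
Example combining both return modes (a sketch; layer sizes are hypothetical): passing a module prunes its weight, while calling prune without inp yields a PruneLayer for activations:

import torch.nn as nn
from qsparse import prune

net = nn.Sequential(
    prune(nn.Conv2d(3, 16, 3, padding=1), sparsity=0.5),  # weight pruning
    nn.ReLU(),
    prune(sparsity=0.5, name="act1"),                     # feature pruning
    nn.Conv2d(16, 16, 3, padding=1),
)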