-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathprune.py
1372 lines (1126 loc) · 56.4 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
import torch
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the tensor this method acts on; assigned in :meth:`apply`.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` according to the specific pruning
        method recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        # cast the mask so the product keeps the dtype of the original tensor
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method now hooked onto ``module``: either the newly
            created instance of ``cls`` or a :class:`PruningContainer`
            wrapping it together with previously applied methods.

        Raises:
            AssertionError: if more than one pruning hook is already attached
                to ``name``, or if ``importance_scores`` does not have the
                same shape as the parameter.
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with hook.name == name
            # assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            # Adjacent f-strings (rather than a backslash continuation inside
            # one literal) keep the message free of stray or missing spaces.
            assert found <= 1, (
                f"Avoid adding multiple pruning hooks to the "
                f"same tensor {name} of module {module}. Use a PruningContainer."
            )

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `methods` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    # Have the pruning method remember the name of its tensor
                    # setattr(container, '_tensor_name', name)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importances scores.

        orig = getattr(module, name)
        if importance_scores is not None:
            assert (
                importance_scores.shape == orig.shape
            ), f"importance_scores should have the same shape as parameter {name} of {module}"
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # and deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception:
            # Only a first-time pruning moved the parameter to `name + '_orig'`,
            # so only that case needs the parameter moved back.
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            # bare raise re-raises the active exception unchanged
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and return a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.

        Returns:
            pruned version of tensor ``t``.

        Raises:
            AssertionError: if ``importance_scores`` is given and its shape
                differs from the shape of ``t``.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Args:
            module (nn.Module): module from which to remove the
                reparameterization of the tensor this method acts on.

        Note:
            Pruning itself is NOT undone or reversed!
        """
        # before removing pruning from a tensor, it has to have been applied
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned before pruning can be removed"  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        # reuse the `_orig` parameter object, now holding the masked values
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod or an iterable of
    them.
    """

    def __init__(self, *args):
        """Build the container from zero or more pruning methods.

        Args:
            *args: nothing (empty container whose ``_tensor_name`` is left
                unset), a single pruning method, several pruning methods, or
                a single iterable of pruning methods. All methods must act on
                the same tensor name.
        """
        self._pruning_methods: tuple[BasePruningMethod, ...] = ()
        # ``*args`` is always a tuple, so an "iterable of methods" can only
        # arrive as one positional argument. Pruning methods are ruled out
        # first because PruningContainer itself defines ``__iter__``.
        if (
            len(args) == 1
            and not isinstance(args[0], BasePruningMethod)
            and isinstance(args[0], Iterable)
        ):
            methods = tuple(args[0])
        else:
            methods = args

        # The container adopts the tensor name of its first real method so
        # that add_pruning_method() has something to compare against.
        first = next((m for m in methods if m is not None), None)
        if first is not None:
            self._tensor_name = first._tensor_name

        for method in methods:
            self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is neither ``None`` nor a
                :class:`BasePruningMethod` instance.
            ValueError: if ``method`` acts on a tensor name different from
                the container's.
        """
        # check that we're adding a pruning method to the container
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        return len(self._pruning_methods)

    def __iter__(self):
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).
        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and returns a new mask.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).
            """
            new_mask = mask  # start off from existing mask
            # NOTE: when the dtypes already match, ``.to`` hands back ``mask``
            # itself, so the in-place write at the end also updates ``mask``.
            new_mask = new_mask.to(dtype=t.dtype)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                dim = method.dim
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # if dim is still negative after subtracting it from n_dims
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim = dim that aren't already tots 0ed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create slice to identify what to prune
                slc = [slice(None)] * n_dims
                slc[dim] = keep_channel

            elif method.PRUNING_TYPE == "global":
                n_dims = len(t.shape)  # "is this a 2D tensor? 3D? ..."
                slc = [slice(None)] * n_dims

            else:
                raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

            # Multi-dimensional indices must be tuples; indexing a tensor with
            # a plain list of slices relies on legacy sequence handling.
            if isinstance(slc, list):
                slc = tuple(slc)

            # compute the new mask on the unpruned slice of the tensor t
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        # only the most recently added method computes a new partial mask;
        # earlier methods are already reflected in ``default_mask``
        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
class Identity(BasePruningMethod):
    r"""Pruning method that sets up the pruning parametrization without pruning any unit."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # No unit is pruned here: the incoming mask is handed back as is.
        return default_mask

    @classmethod
    def apply(cls, module, name):
        r"""Install the pruning reparametrization on ``module[name]``.

        Delegates to the parent ``apply``, which registers the forward
        pre-hook and expresses the tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name)
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Return a copy of ``default_mask`` with randomly chosen entries zeroed.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations. It is cloned, not modified.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` in which
            the selected entries are set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to zero out otherwise
            prob = torch.rand_like(t)
            # reshape (not view): ``rand_like`` follows the layout of ``t``,
            # so a non-contiguous ``t`` would make ``view(-1)`` raise.
            topk = torch.topk(prob.reshape(-1), k=nparams_toprune)
            # ``mask`` is a contiguous clone, so this view writes in place
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        r"""Return a copy of ``default_mask`` with the smallest-magnitude entries of ``t`` zeroed.

        Args:
            t (torch.Tensor): tensor of importance scores for the parameter
                to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations. It is cloned, not modified.

        Returns:
            mask (torch.Tensor): contiguous clone of ``default_mask`` in which
            the selected entries are set to 0.
        """
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to zero out otherwise
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k
            # reshape (not view) so that non-contiguous score tensors can be
            # flattened too; for contiguous input the two are equivalent.
            topk = torch.topk(
                torch.abs(t).reshape(-1), k=nparams_toprune, largest=False
            )
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s for every channel that survives the random draw, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1] to associate to each channel
            prob = torch.rand(nchannels)
            # generate mask for each channel by 0ing out the channels that
            # got assigned the k = nchannels_toprune lowest values in prob
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = channel_mask
            # index with a tuple: a list of slices is a non-tuple sequence
            # index and relies on legacy handling
            mask[tuple(slc)] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)
class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied.  Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning prunes entire channels so we need to know the
        # L_n norm along each channel to then find the topk based on this
        # metric
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k; largest=False --> bottom k
        # Keep the largest k channels along dim=self.dim
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)
        # topk will have .indices and .values

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, indices):
            # init mask to 0
            mask = torch.zeros_like(t)
            # e.g.: slc = [None, None, None], if len(t.shape) = 3
            slc = [slice(None)] * len(t.shape)
            # replace a None at position=dim with indices
            # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
            slc[dim] = indices
            # use slc to slice mask and replace all its entries with 1s
            # e.g.: mask[:, :, [0, 2, 3]] = 1
            # index with a tuple: a list of slices is a non-tuple sequence
            # index and relies on legacy handling
            mask[tuple(slc)] = 1
            return mask

        if nparams_toprune == 0:  # nothing to prune: keep the mask as it is
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )
class CustomFromMask(BasePruningMethod):
    r"""Pruning method that applies a mask supplied by the caller."""

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        # mask supplied by the caller; combined with the default mask later
        self.mask = mask

    def compute_mask(self, t, default_mask):
        r"""Return the elementwise product of ``default_mask`` and the stored mask.

        Args:
            t (torch.Tensor): tensor to prune; not read by this method.
            default_mask (torch.Tensor): mask from previous pruning
                iterations; must have the same shape as the stored mask.

        Returns:
            mask (torch.Tensor): ``default_mask`` multiplied by the stored
            mask cast to the dtype of ``default_mask``.
        """
        user_mask = self.mask
        assert default_mask.shape == user_mask.shape
        return default_mask * user_mask.to(dtype=default_mask.dtype)

    @classmethod
    def apply(cls, module, name, mask):
        r"""Install pruning with a caller-supplied mask on ``module[name]``.

        Delegates to the parent ``apply``, which registers the forward
        pre-hook and expresses the tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask: mask forwarded to the constructor of this class.
        """
        return super().apply(module, name, mask=mask)
def identity(module, name):
    r"""Set up the pruning reparametrization while pruning nothing.

    The tensor behind the parameter ``name`` of ``module`` gets the pruning
    reparametrization, but no unit is actually removed. ``module`` is
    changed in place and is also returned. The changes are:

    1) a named buffer ``name+'_mask'`` is added, holding the binary mask
       that the pruning method applies to the parameter ``name``.
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module that owns the tensor to prune.
        name (str): name of the parameter inside ``module`` that pruning
            acts on.

    Returns:
        module (nn.Module): the same module, now carrying the
        reparametrization.

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    # The method's own return value is discarded; callers get the module.
    Identity.apply(module, name)
    return module
def random_unstructured(module, name, amount):
    r"""Prune randomly chosen (currently unpruned) units of a tensor.

    The tensor behind the parameter ``name`` of ``module`` loses the given
    ``amount`` of (currently unpruned) units, picked at random.
    ``module`` is changed in place and is also returned. The changes are:

    1) a named buffer ``name+'_mask'`` is added, holding the binary mask
       that the pruning method applies to the parameter ``name``.
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module that owns the tensor to prune.
        name (str): name of the parameter inside ``module`` that pruning
            acts on.
        amount (int or float): how much to prune. A ``float`` must lie
            between 0.0 and 1.0 and is read as the fraction of parameters
            to prune; an ``int`` is read as the absolute number of
            parameters to prune.

    Returns:
        module (nn.Module): the same module, now pruned.

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    # The method's own return value is discarded; callers get the module.
    RandomUnstructured.apply(module, name, amount)
    return module
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune the units of a tensor that have the lowest L1-norm.

    The tensor behind the parameter ``name`` of ``module`` loses the given
    ``amount`` of (currently unpruned) units, namely those with the lowest
    L1-norm. ``module`` is changed in place and is also returned. The
    changes are:

    1) a named buffer ``name+'_mask'`` is added, holding the binary mask
       that the pruning method applies to the parameter ``name``.
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module that owns the tensor to prune.
        name (str): name of the parameter inside ``module`` that pruning
            acts on.
        amount (int or float): how much to prune. A ``float`` must lie
            between 0.0 and 1.0 and is read as the fraction of parameters
            to prune; an ``int`` is read as the absolute number of
            parameters to prune.
        importance_scores (torch.Tensor): importance scores (same shape as
            the module parameter) used to compute the pruning mask. Each
            value gives the importance of the matching element of the
            parameter. When left unspecified or None, the module parameter
            itself is used instead.

    Returns:
        module (nn.Module): the same module, now pruned.

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    # The method's own return value is discarded; callers get the module.
    L1Unstructured.apply(module, name, amount=amount, importance_scores=importance_scores)
    return module
def random_structured(module, name, amount, dim):
    r"""Prune randomly chosen channels of a tensor along a given dimension.

    The tensor behind the parameter ``name`` of ``module`` loses the given
    ``amount`` of (currently unpruned) channels along ``dim``, picked at
    random. ``module`` is changed in place and is also returned. The
    changes are:

    1) a named buffer ``name+'_mask'`` is added, holding the binary mask
       that the pruning method applies to the parameter ``name``.
    2) the parameter ``name`` is replaced by its pruned version, and the
       original (unpruned) parameter is kept in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module that owns the tensor to prune.
        name (str): name of the parameter inside ``module`` that pruning
            acts on.
        amount (int or float): how much to prune. A ``float`` must lie
            between 0.0 and 1.0 and is read as a fraction; an ``int`` is
            read as an absolute count.
        dim (int): index of the dim along which channels to prune are
            defined.

    Returns:
        module (nn.Module): the same module, now pruned.

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    # The method's own return value is discarded; callers get the module.
    RandomStructured.apply(module, name, amount, dim)
    return module
def ln_structured(module, name, amount, n, dim, importance_scores=None):
r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.
Prunes tensor corresponding to parameter called ``name`` in ``module``
by removing the specified ``amount`` of (currently unpruned) channels
along the specified ``dim`` with the lowest L\ ``n``-norm.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called ``name+'_mask'`` corresponding to the
binary mask applied to the parameter ``name`` by the pruning method.
2) replacing the parameter ``name`` by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
``name+'_orig'``.
Args:
module (nn.Module): module containing the tensor to prune
name (str): parameter name within ``module`` on which pruning
will act.
amount (int or float): quantity of parameters to prune.
If ``float``, should be between 0.0 and 1.0 and represent the
fraction of parameters to prune. If ``int``, it represents the
absolute number of parameters to prune.
n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
entries for argument ``p`` in :func:`torch.norm`.
dim (int): index of the dim along which we define channels to prune.