# -*- coding: utf-8 -*-
"""
(beta) Accelerating BERT with semi-structured (2:4) sparsity
=============================================================
**Author**: `Jesse Cai <https://github.com/jcaip>`_

"""

####################################################################
# Overview
# --------
#
# Like other forms of sparsity, **semi-structured sparsity** is a model
# optimization technique that seeks to reduce the memory overhead and
# latency of a neural network at the expense of some model accuracy. It is
# also known as **fine-grained structured sparsity** or **2:4 structured
# sparsity**.
#
# Semi-structured sparsity derives its name from its unique sparsity
# pattern, where n out of every 2n elements are pruned. We most often see
# n=2, hence 2:4 sparsity. Semi-structured sparsity is particularly
# interesting because it can be efficiently accelerated on GPUs and
# doesn’t degrade model accuracy as much as other sparsity patterns.
#
# With the introduction of
# `semi-structured sparsity support <https://pytorch.org/docs/2.1/sparse.html#sparse-semi-structured-tensors>`_,
# it is possible to prune and accelerate a semi-structured sparse model
# without leaving PyTorch. We will explain this process in this tutorial.
#
# .. image:: ../../_static/img/pruning_flow.jpg
#
# By the end of this tutorial, we will have sparsified a BERT
# question-answering model to be 2:4 sparse, fine-tuning it to recover
# nearly all of the lost F1 (86.92 dense vs 86.48 sparse). Finally, we will
# accelerate this 2:4 sparse model for inference, yielding a 1.3x speedup.
#

#####################################################
# Requirements
# ------------
#
# - PyTorch >= 2.1.
# - An NVIDIA GPU with semi-structured sparsity support (Compute
#   Capability 8.0+).
#
# This tutorial is designed for beginners to semi-structured sparsity and
# sparsity in general. For users with existing 2:4 sparse models,
# accelerating ``nn.Linear`` layers for inference with
# ``to_sparse_semi_structured`` is quite straightforward. Here is an example:
#

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                     globals={"linear": linear,
                              "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")


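######################################################################
# If you want a quick look at what the conversion did, you can print the weight:
# it is now backed by a compressed ``SparseSemiStructuredTensor`` (the specified
# elements plus packed metadata) instead of a plain dense tensor. This optional
# peek is only illustrative; the exact ``repr`` varies across PyTorch versions.
#
# .. code-block:: python
#
#    print(linear.weight)
#
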
######################################################################
# What problem does semi-structured sparsity solve?
# -------------------------------------------------
#
# The general motivation behind sparsity is simple: if there are zeros in
# your network, you can optimize efficiency by not storing or computing those
# parameters. However, the specifics of sparsity are tricky. Zeroing out
# parameters doesn’t affect the latency / memory overhead of our model out
# of the box.
#
# This is because the dense tensor still contains the pruned (zero)
# elements, and the dense matrix multiplication kernel will still operate
# on those elements. In order to realize performance gains, we need to
# swap out dense kernels for sparse kernels, which skip calculations
# involving pruned elements.
#
# To do this, these kernels work on sparse matrices, which do not store
# the pruned elements and store the specified elements in a compressed
# format.
#
# For semi-structured sparsity, we store exactly half of the original
# parameters along with some compressed metadata about how the elements
# were arranged.
#
# .. image:: https://developer-blogs.nvidia.com/wp-content/uploads/2023/06/2-4-structured-sparsity-pattern.png
#    :align: center
#    :width: 80%
#
# Image sourced from `NVIDIA blog post <https://developer.nvidia.com/blog/structured-sparsity-in-the-nvidia-ampere-architecture-and-applications-in-search-engines/>`_ on semi-structured sparsity.
#
# There are many different sparse layouts, each with their own benefits
# and drawbacks. The 2:4 semi-structured sparse layout is particularly
# interesting for two reasons:
#
# * Unlike previous sparse formats,
#   semi-structured sparsity was designed to be efficiently accelerated on
#   GPUs. In 2020, NVIDIA introduced hardware support for semi-structured
#   sparsity with their Ampere architecture, and has also released fast
#   sparse kernels via
#   CUTLASS and `cuSPARSELt <https://docs.nvidia.com/cuda/cusparselt/index.html>`__.
#
# * At the same time, semi-structured sparsity tends to have a milder
#   impact on model accuracy compared to other sparse formats, especially
#   when accounting for more advanced pruning / fine-tuning methods. NVIDIA
#   has shown in their `white paper <https://arxiv.org/abs/2104.08378>`_
#   that a simple paradigm of magnitude pruning once to be 2:4 sparse and
#   then retraining the model yields nearly identical model accuracies.
#
# Semi-structured sparsity sits in a sweet spot, providing a 2x (theoretical)
# speedup at a much lower sparsity level (50%), while still being granular
# enough to preserve model accuracy.
#
# +---------------------+-------------+--------+------------+-------------+
# | Network             | Data Set    | Metric | Dense FP16 | Sparse FP16 |
# +=====================+=============+========+============+=============+
# | ResNet-50           | ImageNet    | Top-1  | 76.1       | 76.2        |
# +---------------------+-------------+--------+------------+-------------+
# | ResNeXt-101_32x8d   | ImageNet    | Top-1  | 79.3       | 79.3        |
# +---------------------+-------------+--------+------------+-------------+
# | Xception            | ImageNet    | Top-1  | 79.2       | 79.2        |
# +---------------------+-------------+--------+------------+-------------+
# | SSD-RN50            | COCO2017    | bbAP   | 24.8       | 24.8        |
# +---------------------+-------------+--------+------------+-------------+
# | MaskRCNN-RN50       | COCO2017    | bbAP   | 37.9       | 37.9        |
# +---------------------+-------------+--------+------------+-------------+
# | FairSeq Transformer | EN-DE WMT14 | BLEU   | 28.2       | 28.5        |
# +---------------------+-------------+--------+------------+-------------+
# | BERT-Large          | SQuAD v1.1  | F1     | 91.9       | 91.9        |
# +---------------------+-------------+--------+------------+-------------+
#
# Semi-structured sparsity has an additional advantage from a workflow
# perspective. Because the sparsity level is fixed at 50%, it is easier to
# decompose the problem of sparsifying a model into two distinct
# subproblems:
#
# - Accuracy - How can we find a set of 2:4 sparse weights that minimize
#   the accuracy degradation of our model?
#
# - Performance - How can we accelerate our 2:4 sparse weights for
#   inference and reduced memory overhead?
#

#####################################################################
# .. math::
#
#    \begin{bmatrix}
#       1 & 1 & 0 & 0 \\
#       0 & 0 & 1 & 1 \\
#       1 & 0 & 0 & 0 \\
#       0 & 0 & 1 & 1 \\
#    \end{bmatrix}
#
# The natural handoff point between these two subproblems is a zeroed-out
# dense tensor: the pruned elements are set to zero, but the tensor is
# still stored in a dense layout. Our inference solution is designed to
# compress and accelerate tensors in this format. We anticipate many users
# coming up with custom masking solutions, as this is an active area of
# research.
#
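# If you are writing your own masking logic, one quick sanity check is to count
# the nonzeros in every contiguous block of four elements along a row of the
# zeroed-out weight. The helper below is only an illustrative sketch (the name
# ``is_24_sparse`` is ours, not a PyTorch API), and it assumes the last dimension
# is divisible by 4:
#
# .. code-block:: python
#
#    def is_24_sparse(weight):
#        # at most 2 nonzero entries in every block of 4 along the last dimension
#        blocks = weight.reshape(weight.shape[0], -1, 4)
#        return bool(((blocks != 0).sum(dim=-1) <= 2).all())
#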
# Now that we’ve learned a little more about semi-structured sparsity,
# let’s apply it to a BERT model trained on a question answering task,
# SQuAD.
#
# Intro & Setup
# -------------
#
# Let’s start by importing all the packages we need.
#

# If you are running this in Google Colab, run:
# .. code-block:: python
#
#    !pip install datasets transformers evaluate accelerate pandas
#
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)

# Set default device to "cuda:0"
torch.set_default_device(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

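# 2:4 sparse acceleration needs a GPU with compute capability 8.0 or newer
# (see the Requirements above); this optional check prints the capability of
# the current device.
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())
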
######################################################################
# We’ll also need to define some helper functions that are specific to the
# dataset / task at hand. These were adapted from
# `this Hugging Face course <https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt>`__.
#

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        # map each feature back to the example it came from
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        # keep offsets only for context tokens so we can recover answer spans later
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


######################################################################
# Now that those are defined, we just need one additional helper function,
# which will help us benchmark our model.
#

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            # warm up and time the eager model
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[batch_size] = p50

            # compile the model and time it again
            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_sec


######################################################################
# We will get started by loading our model and tokenizer, and then setting
# up our dataset.
#

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

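# optional sanity check: because long contexts are split into several features
# (``return_overflowing_tokens=True``), the tokenized validation set can contain
# more rows than the raw validation set
print(len(squad_dataset["validation"]), len(tokenized_squad_dataset["validation"]))
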
######################################################################
# Establishing a baseline
# =======================
#
# Next, we’ll train a quick baseline of our model on SQuAD. This task asks
# our model to identify spans, or segments of text, in a given context
# (Wikipedia articles) that answer a given question. Running the following
# code gives me an F1 score of 86.9. This is quite close to the reported
# NVIDIA score and the difference is likely due to BERT-base
# vs. BERT-large or fine-tuning hyperparameters.
#

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    report_to=None,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity requires fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")


######################################################################
# Pruning BERT to be 2:4 sparse
# -----------------------------
#
# Now that we have our baseline, it’s time to prune BERT. There are many
# different pruning strategies, but one of the most common is **magnitude
# pruning**, which seeks to remove the weights with the lowest L1 norm.
# Magnitude pruning was used by NVIDIA in all their results and is a
# common baseline.
#
# To do this, we will use the ``torch.ao.pruning`` package, which contains
# a weight-norm (magnitude) sparsifier. These sparsifiers work by applying
# mask parametrizations to the weight tensors in a model. This lets them
# simulate sparsity by masking out the pruned weights.
#
# We’ll also have to decide which layers of the model to apply sparsity to,
# which in this case is all of the ``nn.Linear`` layers, except for the
# task-specific head outputs. That’s because semi-structured sparsity has
# `shape constraints <https://pytorch.org/docs/2.1/sparse.html#constructing-sparse-semi-structured-tensors>`_,
# and the task-specific ``nn.Linear`` layers do not satisfy them.
#
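# As a concrete example of what 2:4 magnitude pruning does to a single 1x4 block
# of weights, the two entries with the largest absolute value are kept and the
# other two are zeroed out. This is only a hand-rolled illustration; the actual
# masking below is handled by the ``WeightNormSparsifier``:
#
# .. code-block:: python
#
#    block = torch.tensor([0.3, -1.2, 0.05, 0.7])
#    mask = torch.zeros_like(block)
#    mask[block.abs().topk(2).indices] = 1
#    print(mask * block)  # keeps -1.2 and 0.7, zeros out the rest
#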

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if the module is an ``nn.Linear`` and is part of a BERT transformer layer
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]

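# optional: check how many weights will be sparsified; this selects the attention
# and feed-forward ``nn.Linear`` layers in every encoder block, but not the
# task-specific QA head
print(len(sparse_config))
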
######################################################################
# The first step for pruning the model is to insert parametrizations for
# masking the weights of the model. This is done by the prepare step.
# Any time we access ``.weight``, we will get ``mask * weight`` instead.
#

# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)

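# optional: the sparsifier attaches its masks via module parametrization, so the
# selected ``nn.Linear`` layers should now report as parametrized
from torch.nn.utils import parametrize
print(parametrize.is_parametrized(model.bert.encoder.layer[0].output.dense, "weight"))
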
######################################################################
# Then, we’ll take a single pruning step. All pruners implement an
# ``update_mask()`` method that updates the mask with the logic determined
# by the pruner implementation. The step method calls ``update_mask`` for
# the weights specified in the sparse config.
#
# We will also evaluate the model to show the accuracy degradation of
# zero-shot pruning, or pruning without fine-tuning / retraining.
#

sparsifier.step()
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        pruned = compute_metrics(
            *predictions.predictions,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        print("pruned eval metrics:", pruned)

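# optional sanity check: after one pruning step, roughly half of the entries in
# each sparsified weight should be masked to zero
print((model.bert.encoder.layer[0].intermediate.dense.weight == 0).float().mean())
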
######################################################################
# In this state, we can start fine-tuning the model, updating the elements
# that wouldn’t be pruned to better account for the accuracy loss. Once
# we’ve reached a satisfactory state, we can call ``squash_mask`` to fuse
# the mask and the weight together. This will remove the parametrizations
# and we are left with a zeroed-out 2:4 dense model.
#

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])

df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")

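# optional: verify the 2:4 pattern on the squashed dense weight; every contiguous
# block of 4 elements along a row should now contain at most 2 nonzeros
w = model.bert.encoder.layer[0].intermediate.dense.weight
print(bool(((w.reshape(w.shape[0], -1, 4) != 0).sum(dim=-1) <= 2).all()))
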
######################################################################
# Accelerating 2:4 sparse models for inference
# --------------------------------------------
#
# Now that we have a model in this format, we can accelerate it for
# inference just like in the quickstart example at the start of this tutorial.
#

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.no_grad():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
    start_logits, end_logits = predictions.predictions
    metrics_sparse = compute_metrics(
        start_logits,
        end_logits,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
    print("sparse eval metrics: ", metrics_sparse)
    sparse_perf = measure_execution_time(
        model,
        batch_sizes,
        tokenized_squad_dataset["validation"],
    )
    print("sparse perf metrics: ", sparse_perf)


######################################################################
# Retraining our model after magnitude pruning has recovered nearly all of
# the F1 that was lost when the model was pruned. At the same time we
# have achieved a 1.28x speedup for ``bs=16``. Note that not all shapes are
# amenable to performance improvements. When batch sizes are small and
# limited time is spent in compute, sparse kernels may be slower than their
# dense counterparts.
#
# Because semi-structured sparsity is implemented as a tensor subclass, it
# is compatible with ``torch.compile``. When composed with
# ``to_sparse_semi_structured``, we are able to achieve a total 2x speedup
# on BERT.
#
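# A minimal sketch of that composition (``measure_execution_time`` above already
# does this to produce the compiled timings below):
#
# .. code-block:: python
#
#    model_c = torch.compile(model, fullgraph=True)
#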
# .. table::
#
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Metrics            | fp16   | 2:4 sparse   | delta / speedup | compiled  |
#     +====================+========+==============+=================+===========+
#     | Exact Match (%)    | 78.53  | 78.44        | -0.09           |           |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | F1 (%)             | 86.93  | 86.49        | -0.44           |           |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=4)    | 11.10  | 15.54        | 0.71x           | no        |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=16)   | 19.35  | 15.74        | 1.23x           | no        |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=64)   | 72.71  | 59.41        | 1.22x           | no        |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=256)  | 286.65 | 247.63       | 1.14x           | no        |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=4)    | 7.59   | 7.46         | 1.02x           | yes       |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=16)   | 11.47  | 9.68         | 1.18x           | yes       |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=64)   | 41.57  | 36.92        | 1.13x           | yes       |
#     +--------------------+--------+--------------+-----------------+-----------+
#     | Time (ms, bs=256)  | 159.22 | 142.23       | 1.12x           | yes       |
#     +--------------------+--------+--------------+-----------------+-----------+
#
# Conclusion
# ==========
#
# In this tutorial, we have shown how to prune BERT to be 2:4 sparse and
# how to accelerate a 2:4 sparse model for inference. By taking advantage
# of our ``SparseSemiStructuredTensor`` subclass, we were able to achieve a
# 1.3x speedup over the fp16 baseline, and up to 2x with
# ``torch.compile``. We also demonstrated the benefits of 2:4 sparsity by
# fine-tuning BERT to recover nearly all of the F1 lost during pruning
# (86.92 dense vs 86.48 sparse).
#