GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/semi_structured_sparse.py
# -*- coding: utf-8 -*-
"""
(beta) Accelerating BERT with semi-structured (2:4) sparsity
============================================================
**Author**: `Jesse Cai <https://github.com/jcaip>`_

"""

####################################################################
# Overview
# --------
#
# Like other forms of sparsity, **semi-structured sparsity** is a model
# optimization technique that seeks to reduce the memory overhead and
# latency of a neural network at the expense of some model accuracy. It is
# also known as **fine-grained structured sparsity** or **2:4 structured
# sparsity**.
#
# Semi-structured sparsity derives its name from its unique sparsity
# pattern, where n out of every 2n elements are pruned. We most often see
# n=2, hence 2:4 sparsity. Semi-structured sparsity is particularly
# interesting because it can be efficiently accelerated on GPUs and
# doesn’t degrade model accuracy as much as other sparsity patterns.
#
# With the introduction of
# `semi-structured sparsity support <https://pytorch.org/docs/2.1/sparse.html#sparse-semi-structured-tensors>`_,
# it is possible to prune and accelerate a semi-structured sparse model
# without leaving PyTorch. We will explain this process in this tutorial.
#
# .. image:: ../../_static/img/pruning_flow.jpg
#
# By the end of this tutorial, we will have sparsified a BERT
# question-answering model to be 2:4 sparse, fine-tuning it to recover
# nearly all of the F1 lost to pruning (86.92 dense vs 86.48 sparse).
# Finally, we will accelerate this 2:4 sparse model for inference,
# yielding a 1.3x speedup.
#

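######################################################################
# The small pure-Python check below makes the n:2n pattern concrete. It
# tests whether a flat row of weights has at least n zeros in every
# consecutive group of 2n elements. This is an illustrative sketch, not a
# PyTorch API. The helper name ``is_n_of_2n_sparse`` is hypothetical.
#

```python
def is_n_of_2n_sparse(values, n=2):
    """Return True if every group of 2n consecutive values has at least n zeros.

    Assumes len(values) is a multiple of 2n (4 for 2:4 sparsity).
    """
    group = 2 * n
    if len(values) % group != 0:
        raise ValueError(f"length {len(values)} is not a multiple of {group}")
    for start in range(0, len(values), group):
        block = values[start:start + group]
        if sum(1 for v in block if v == 0) < n:
            return False
    return True


# Each group of 4 keeps at most 2 non-zero values.
print(is_n_of_2n_sparse([0.5, 0.0, 0.0, -1.2, 0.0, 3.1, 0.7, 0.0]))  # True
# The second group of 4 has three non-zeros, so this row is not 2:4 sparse.
print(is_n_of_2n_sparse([0.5, 0.0, 0.0, -1.2, 0.0, 3.1, 0.7, 0.9]))  # False
```
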
#####################################################
# Requirements
# ------------
#
# - PyTorch >= 2.1.
# - An NVIDIA GPU with semi-structured sparsity support (Compute
#   Capability 8.0+).
#
# This tutorial is designed for beginners to semi-structured sparsity and
# sparsity in general. For users with existing 2:4 sparse models,
# accelerating ``nn.Linear`` layers for inference with
# ``to_sparse_semi_structured`` is quite straightforward. Here is an example:
#

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                     globals={"linear": linear,
                              "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")


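######################################################################
# The mask in the example above relies on the shape rules of
# ``Tensor.tile``. A 1-D pattern tiled with a 2-D ``reps`` is first treated
# as shape ``(1, len(pattern))``, so ``[0, 0, 1, 1]`` tiled by
# ``(3072, 2560)`` has shape ``(3072, 10240)``. That is the
# ``(out_features, in_features)`` layout of ``nn.Linear(10240, 3072).weight``.
# The pure-Python stand-in below shows the same rule at a small size. It is
# a sketch that assumes a 1-D pattern and a 2-tuple ``reps``, and the helper
# name ``tile_pattern`` is hypothetical.
#

```python
def tile_pattern(pattern, reps):
    """Mimic ``torch.Tensor(pattern).tile(reps)`` for a 1-D pattern and 2-tuple reps."""
    rows, cols = reps
    # the 1-D pattern acts like one row; repeat it `cols` times horizontally
    row = list(pattern) * cols
    # then stack `rows` independent copies vertically
    return [list(row) for _ in range(rows)]


mask = tile_pattern([0, 0, 1, 1], (2, 3))
print(len(mask), len(mask[0]))    # 2 12
print(mask[0])                    # [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
print(sum(mask[0]) / len(mask[0]))  # 0.5, i.e. half of each row is kept

# At the tutorial's reps the mask has 3072 rows and 2560 * 4 columns.
rows, cols = 3072, 2560 * 4
print((rows, cols))  # (3072, 10240)
```
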
######################################################################
# What problem does semi-structured sparsity solve?
# -------------------------------------------------
#
# The general motivation behind sparsity is simple: if there are zeros in
# your network, you can optimize efficiency by not storing or computing those
# parameters. However, the specifics of sparsity are tricky. Zeroing out
# parameters doesn’t affect the latency / memory overhead of our model out
# of the box.
#
# This is because the dense tensor still contains the pruned (zero)
# elements, and the dense matrix multiplication kernel still operates on
# them. In order to realize performance gains, we need to swap out dense
# kernels for sparse kernels, which skip calculations involving pruned
# elements.
#
# To do this, these kernels work on sparse matrices, which do not store
# the pruned elements and store the specified elements in a compressed
# format.
#
# For semi-structured sparsity, we store exactly half of the original
# parameters along with some compressed metadata about how the elements
# were arranged.
#
# .. image:: https://developer-blogs.nvidia.com/wp-content/uploads/2023/06/2-4-structured-sparsity-pattern.png
#    :align: center
#    :width: 80%
#
# Image sourced from `NVIDIA blog post <https://developer.nvidia.com/blog/structured-sparsity-in-the-nvidia-ampere-architecture-and-applications-in-search-engines/>`_ on semi-structured sparsity.
#
# There are many different sparse layouts, each with their own benefits
# and drawbacks. The 2:4 semi-structured sparse layout is particularly
# interesting for two reasons:
#
# * Unlike previous sparse formats,
#   semi-structured sparsity was designed to be efficiently accelerated on
#   GPUs. In 2020, NVIDIA introduced hardware support for semi-structured
#   sparsity with its Ampere architecture, and has also released fast
#   sparse kernels via CUTLASS and
#   `cuSPARSELt <https://docs.nvidia.com/cuda/cusparselt/index.html>`__.
#
# * At the same time, semi-structured sparsity tends to have a milder
#   impact on model accuracy compared to other sparse formats, especially
#   when accounting for more advanced pruning / fine-tuning methods. NVIDIA
#   has shown in their `white paper <https://arxiv.org/abs/2104.08378>`_
#   that a simple paradigm of magnitude pruning once to 2:4 sparsity and
#   then retraining the model yields nearly identical model accuracies.
#
# Semi-structured sparsity sits in a sweet spot, providing a 2x (theoretical)
# speedup at a much lower sparsity level (50%), while still being granular
# enough to preserve model accuracy.
#
# +---------------------+-------------+--------+------------+-------------+
# | Network             | Data Set    | Metric | Dense FP16 | Sparse FP16 |
# +=====================+=============+========+============+=============+
# | ResNet-50           | ImageNet    | Top-1  | 76.1       | 76.2        |
# +---------------------+-------------+--------+------------+-------------+
# | ResNeXt-101_32x8d   | ImageNet    | Top-1  | 79.3       | 79.3        |
# +---------------------+-------------+--------+------------+-------------+
# | Xception            | ImageNet    | Top-1  | 79.2       | 79.2        |
# +---------------------+-------------+--------+------------+-------------+
# | SSD-RN50            | COCO2017    | bbAP   | 24.8       | 24.8        |
# +---------------------+-------------+--------+------------+-------------+
# | MaskRCNN-RN50       | COCO2017    | bbAP   | 37.9       | 37.9        |
# +---------------------+-------------+--------+------------+-------------+
# | FairSeq Transformer | EN-DE WMT14 | BLEU   | 28.2       | 28.5        |
# +---------------------+-------------+--------+------------+-------------+
# | BERT-Large          | SQuAD v1.1  | F1     | 91.9       | 91.9        |
# +---------------------+-------------+--------+------------+-------------+
#
# Semi-structured sparsity has an additional advantage from a workflow
# perspective. Because the sparsity level is fixed at 50%, it is easier to
# decompose the problem of sparsifying a model into two distinct
# subproblems:
#
# - Accuracy - How can we find a set of 2:4 sparse weights that minimize
#   the accuracy degradation of our model?
#
# - Performance - How can we accelerate our 2:4 sparse weights for
#   inference and reduced memory overhead?
#

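######################################################################
# The compressed layout described above keeps the non-zero half of each
# group of 4 plus metadata recording where those values sat. The
# pure-Python sketch below illustrates it under simplifying assumptions.
# Each kept value's position inside its group is stored as a plain
# integer in 0..3, which is 2 bits of information. The real CUTLASS /
# cuSPARSELt metadata uses a packed, hardware-specific encoding. The names
# ``compress_2to4`` and ``decompress_2to4`` are hypothetical.
#

```python
def compress_2to4(row):
    """Split a 2:4 sparse row into (values, indices) with exactly 2 entries per group of 4.

    Groups with fewer than 2 non-zeros are padded with explicit zero entries.
    ``indices`` holds each kept entry's position (0-3) inside its group.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    values, indices = [], []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        kept = [i for i, v in enumerate(group) if v != 0]
        if len(kept) > 2:
            raise ValueError(f"group starting at {start} has more than 2 non-zeros")
        # pad with positions of zero entries so every group stores exactly 2 values
        for i in range(4):
            if len(kept) == 2:
                break
            if i not in kept:
                kept.append(i)
        kept.sort()
        values.extend(group[i] for i in kept)
        indices.extend(kept)
    return values, indices


def decompress_2to4(values, indices):
    """Rebuild the dense row from the output of compress_2to4."""
    row = []
    for g in range(len(values) // 2):
        group = [0.0] * 4
        for k in (2 * g, 2 * g + 1):
            group[indices[k]] = values[k]
        row.extend(group)
    return row


row = [0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.0, 0.0]
values, indices = compress_2to4(row)
print(values)   # [1.5, -2.0, 3.0, 0.0]
print(indices)  # [1, 3, 0, 1]
print(decompress_2to4(values, indices) == row)  # True
```

With fp16 values and 2-bit indices, this simplified scheme needs
2 * 16 + 2 * 2 = 36 bits per group of 4 instead of 64 bits for the dense group.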
#####################################################################
# .. math::
#
#    \begin{bmatrix}
#    1 & 1 & 0 & 0 \\
#    0 & 0 & 1 & 1 \\
#    1 & 0 & 0 & 0 \\
#    0 & 0 & 1 & 1 \\
#    \end{bmatrix}
#
# The natural handoff point between these two problems is a zeroed-out
# dense tensor. Our inference solution is designed to compress and
# accelerate tensors in this format. We anticipate many users coming up
# with custom masking solutions, as this is an active area of research.
#
# Now that we’ve learned a little more about semi-structured sparsity,
# let’s apply it to a BERT model trained on a question answering task,
# SQuAD.
#
# Intro & Setup
# -------------
#
# Let’s start by importing all the packages we need.
#

# If you are running this in Google Colab, first run:
#
# .. code-block:: python
#
#    !pip install datasets transformers evaluate accelerate pandas
#
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)


######################################################################
# We’ll also need to define some helper functions that are specific to the
# dataset / task at hand. These were adapted from
# `this Hugging Face course <https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt>`__.
#

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


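######################################################################
# The span-selection loop in ``compute_metrics`` is hard to follow with real
# logits. The toy sketch below applies the same rules to made-up numbers,
# using pure Python instead of NumPy. It takes the top ``n_best`` start
# and end indices. It skips spans touching non-context tokens (``None``
# offsets), spans that end before they start, and spans longer than
# ``max_answer_length``. It then keeps the highest
# ``start_logit + end_logit``. The helper ``best_span`` and all of its
# data are hypothetical.
#

```python
def best_span(start_logits, end_logits, offsets, context, n_best=20, max_answer_length=30):
    """Return (text, score) of the best valid span, or ("", None) if no span is valid."""
    # Indices sorted by descending logit, truncated to the n_best highest
    top_starts = sorted(range(len(start_logits)), key=lambda i: start_logits[i], reverse=True)[:n_best]
    top_ends = sorted(range(len(end_logits)), key=lambda i: end_logits[i], reverse=True)[:n_best]
    best = ("", None)
    for s in top_starts:
        for e in top_ends:
            # Skip tokens that are not part of the context
            if offsets[s] is None or offsets[e] is None:
                continue
            # Skip spans with negative length or longer than max_answer_length
            if e < s or e - s + 1 > max_answer_length:
                continue
            score = start_logits[s] + end_logits[e]
            if best[1] is None or score > best[1]:
                best = (context[offsets[s][0]:offsets[e][1]], score)
    return best


context = "Paris is the capital of France."
# Token 0 stands in for a question/special token, so its offset is None.
offsets = [None, (0, 5), (6, 8), (9, 12), (13, 20), (21, 23), (24, 30)]
start_logits = [9.0, 4.0, 0.1, 0.2, 0.3, 0.1, 1.0]
end_logits = [8.0, 5.0, 0.2, 0.1, 0.4, 0.2, 2.0]
print(best_span(start_logits, end_logits, offsets, context))  # ('Paris', 9.0)
```

Token 0 has the largest logits but is skipped because it lies outside the context. This mirrors how ``preprocess_validation_function`` sets non-context offsets to ``None``.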
######################################################################
# Now that those are defined, we just need one additional helper function,
# which will help us benchmark our model.
#

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    # median latency per batch size, in milliseconds
    batch_size_to_time_ms = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            # warm-up run before timing
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000  # seconds -> ms
            batch_size_to_time_ms[batch_size] = p50

            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000  # seconds -> ms
            batch_size_to_time_ms[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_ms



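######################################################################
# ``measure_execution_time`` reports the median of
# ``Timer.blocked_autorange()``, scaled from seconds to milliseconds. The
# stdlib-only sketch below roughly illustrates why a warm-up call and a
# median are used, with ``time.perf_counter`` and ``statistics.median``.
# It is an assumption-laden stand-in, not a reimplementation of
# ``torch.utils.benchmark``. For GPU work you would also need to
# synchronize (for example with ``torch.cuda.synchronize()``) before
# reading the clock. The helper name ``median_time_ms`` is hypothetical.
#

```python
import statistics
import time


def median_time_ms(fn, repeats=20, warmup=3):
    """Call fn() `warmup` times untimed, then return the median of `repeats` timings in ms."""
    for _ in range(warmup):
        fn()  # absorbs one-off costs such as caches, lazy init, or compilation
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)  # seconds -> ms
    # the median is less sensitive to occasional slow outliers than the mean
    return statistics.median(samples)


t = median_time_ms(lambda: sum(i * i for i in range(10_000)))
print(f"{t:.3f} ms")
```
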
######################################################################
# We will get started by loading our model and tokenizer, and then setting
# up our dataset.
#

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loaded tokenizer: {model_name}")
print(f"Loaded model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)


######################################################################
# Establishing a baseline
# -----------------------
#
# Next, we’ll train a quick baseline of our model on SQuAD. This task asks
# our model to identify spans, or segments of text, in a given context
# (Wikipedia articles) that answer a given question. Running the following
# code for the full epoch (without the ``max_steps`` limit) gives an F1
# score of 86.9. This is reasonably close to the NVIDIA BERT-Large score
# in the table above (91.9), and the difference is likely due to BERT-base
# vs. BERT-large or fine-tuning hyperparameters.
#

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    # "none" disables logging integrations; in transformers v4, report_to=None falls back to "all"
    report_to="none",
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# The 2:4 sparse kernels used later run in fp16, so we run the dense
# baseline under autocast too for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")


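######################################################################
# The ``squad`` metric loaded in ``compute_metrics`` reports exact match
# and a token-overlap F1, averaged over examples and expressed as
# percentages. The simplified pure-Python version below roughly
# illustrates what F1 measures. Assumption: it normalizes only by
# lowercasing and whitespace splitting. The official SQuAD evaluation also
# strips punctuation and articles and takes the maximum over all reference
# answers. The helper names are hypothetical.
#

```python
from collections import Counter


def exact_match(prediction, reference):
    """1.0 if the (simply normalized) token lists match exactly, else 0.0."""
    return float(prediction.lower().split() == reference.lower().split())


def f1_score(prediction, reference):
    """Token-overlap F1 between a predicted and a reference answer (simplified)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # multiset intersection counts each shared token at most min(count) times
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("Denver Broncos", "denver broncos"))             # 1.0
print(round(f1_score("the Denver Broncos", "Denver Broncos"), 3))  # 0.8
```

The second prediction gets partial credit. Two of its three tokens match, so precision is 2/3, recall is 1, and F1 is 0.8. Exact match gives it no credit at all.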
######################################################################
# Pruning BERT to be 2:4 sparse
# -----------------------------
#
# Now that we have our baseline, it’s time we prune BERT. There are many
# different pruning strategies, but one of the most common is **magnitude
# pruning**, which seeks to remove the weights with the lowest L1 norm.
# Magnitude pruning was used by NVIDIA in all their results and is a
# common baseline.
#
# To do this, we will use the ``torch.ao.pruning`` package, which contains
# a weight-norm (magnitude) sparsifier. These sparsifiers work by applying
# mask parametrizations to the weight tensors in a model. This lets them
# simulate sparsity by masking out the pruned weights.
#
# We’ll also have to decide which layers of the model to apply sparsity to,
# which in this case is all of the ``nn.Linear`` layers, except for the
# task-specific head outputs. That’s because semi-structured sparsity has
# `shape constraints <https://pytorch.org/docs/2.1/sparse.html#constructing-sparse-semi-structured-tensors>`_,
# and the task-specific ``nn.Linear`` layers do not satisfy them.
#

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if ``nn.Linear`` and in the BERT model.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]


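######################################################################
# With ``sparse_block_shape=(1, 4)`` and ``zeros_per_block=2``, the goal
# is that the 2 smallest-magnitude entries in every group of 4 consecutive
# weights in a row get zeroed. The pure-Python sketch below shows that
# per-block rule on a toy matrix. It is a simplified illustration of the
# idea, not the ``WeightNormSparsifier`` implementation, which works with
# tensor ops and supports other sparsity levels and block shapes. The
# helper name ``magnitude_2to4_mask`` is hypothetical, and ties are broken
# by position.
#

```python
def magnitude_2to4_mask(weight):
    """Return a 0/1 mask keeping the 2 largest-|w| entries in each block of 4 per row.

    ``weight`` is a list of rows; every row length must be a multiple of 4.
    """
    mask = []
    for row in weight:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4")
        row_mask = []
        for start in range(0, len(row), 4):
            block = row[start:start + 4]
            # positions of the two largest magnitudes (ties -> earlier index)
            keep = sorted(range(4), key=lambda i: (-abs(block[i]), i))[:2]
            row_mask.extend(1 if i in keep else 0 for i in range(4))
        mask.append(row_mask)
    return mask


weight = [
    [0.1, -0.9, 0.4, 0.05, 2.0, -0.3, 0.2, -1.5],
    [-0.7, 0.6, 0.0, 0.8, 0.01, 0.02, -0.03, 0.04],
]
mask = magnitude_2to4_mask(weight)
for row in mask:
    print(row)
# [0, 1, 1, 0, 1, 0, 0, 1]
# [1, 0, 0, 1, 0, 0, 1, 1]
```
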
######################################################################
# The first step for pruning the model is to insert parametrizations for
# masking the weights of the model. This is done by the prepare step.
# Any time we access ``.weight``, we will get ``mask * weight``
# instead.
#

# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)


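######################################################################
# Conceptually, a mask parametrization turns ``.weight`` into a computed
# value. The original values and the mask are stored separately, and every
# read returns their product. Squashing multiplies them once and drops the
# mask. The class below mimics that behaviour with a Python ``property``
# on plain lists. It is a hypothetical stand-in (``MaskedWeight`` is not a
# PyTorch class) meant only to illustrate the ``prepare`` / ``step`` /
# ``squash_mask`` lifecycle, not ``torch.nn.utils.parametrize`` itself.
#

```python
class MaskedWeight:
    """Toy model of a fake-sparsity parametrization on a flat list of weights."""

    def __init__(self, weight):
        self.original = list(weight)   # trainable, unmasked values
        self.mask = [1] * len(weight)  # all ones until a pruning step runs
        self.squashed = False

    @property
    def weight(self):
        # Every access returns mask * original, like a parametrized ``.weight``
        if self.squashed:
            return list(self.original)
        return [m * w for m, w in zip(self.mask, self.original)]

    def update_mask(self, new_mask):
        self.mask = list(new_mask)

    def squash_mask(self):
        # Fold the mask into the stored values and stop masking on access
        self.original = self.weight
        self.mask = None
        self.squashed = True


layer = MaskedWeight([0.1, -0.9, 0.4, 0.05])
print(layer.weight)              # [0.1, -0.9, 0.4, 0.05]
layer.update_mask([0, 1, 1, 0])  # e.g. produced by a magnitude pruning step
print(layer.weight)              # [0.0, -0.9, 0.4, 0.0]
layer.squash_mask()
print(layer.weight, layer.mask)  # [0.0, -0.9, 0.4, 0.0] None
```

The stored values at masked positions can keep changing, for example during fine-tuning, but reads still return zero there until the mask is squashed.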
######################################################################
# Then, we’ll take a single pruning step. All pruners implement an
# ``update_mask()`` method that updates the mask, with the logic
# determined by the pruner implementation. The ``step()`` method calls
# ``update_mask`` for each weight specified in the sparse config.
#
# We will also evaluate the model to show the accuracy degradation of
# zero-shot pruning, or pruning without fine-tuning / retraining.
#

sparsifier.step()
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    pruned = compute_metrics(
        *predictions.predictions,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
print("pruned eval metrics:", pruned)


######################################################################
# In this state, we can start fine-tuning the model, updating the elements
# that wouldn’t be pruned to better account for the accuracy loss. Once
# we’ve reached a satisfactory state, we can call ``squash_mask`` to fuse the
# mask and the weight together. This will remove the parametrizations and
# leave us with a zeroed-out 2:4 dense model.
#

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])

df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")


######################################################################
# Accelerating 2:4 sparse models for inference
# --------------------------------------------
#
# Now that we have a model in this format, we can accelerate it for
# inference just like in the example in the Requirements section.
#

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.no_grad():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
    start_logits,
    end_logits,
    tokenized_squad_dataset["validation"],
    squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
    model,
    batch_sizes,
    tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)


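######################################################################
# The pure-Python sketch below shows where the savings come from. It runs
# a matrix-vector product on a weight matrix stored in a simplified 2:4
# compressed form: two kept values per group of 4 plus their column
# indices. Only those stored entries are multiplied, so it performs half
# the multiplications of the dense product. The storage scheme and
# function names are our own illustrative assumptions. Real CUTLASS /
# cuSPARSELt kernels store 2-bit in-group positions in a packed,
# hardware-specific layout and run on tensor cores.
#

```python
def compress_rows(weight):
    """Store each row as (values, columns): 2 entries per group of 4 (assumes 2:4 rows)."""
    compressed = []
    for row in weight:
        values, columns = [], []
        for start in range(0, len(row), 4):
            group = row[start:start + 4]
            # the 2 largest magnitudes include every non-zero of a valid 2:4 group
            kept = sorted(sorted(range(4), key=lambda i: -abs(group[i]))[:2])
            values.extend(group[i] for i in kept)
            columns.extend(start + i for i in kept)  # absolute column index
        compressed.append((values, columns))
    return compressed


def sparse_matvec(compressed, x):
    """y = W @ x using only the stored entries; also returns the multiply count."""
    y, mults = [], 0
    for values, columns in compressed:
        acc = 0.0
        for v, col in zip(values, columns):
            acc += v * x[col]
            mults += 1
        y.append(acc)
    return y, mults


def dense_matvec(weight, x):
    """Reference dense product that also multiplies the pruned zeros."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [
    [0.0, -0.9, 0.4, 0.0, 2.0, 0.0, 0.0, -1.5],
    [-0.7, 0.0, 0.0, 0.8, 0.0, 0.0, -0.03, 0.04],
]
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
y_sparse, mults = sparse_matvec(compress_rows(weight), x)
print(y_sparse, mults)  # mults is 8, versus 2 rows * 8 columns = 16 for the dense product
print(dense_matvec(weight, x))
```
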
######################################################################
# Retraining our model after magnitude pruning has recovered nearly all of
# the F1 that was lost when the model was pruned. At the same time, we
# have achieved a 1.23x speedup for ``bs=16``. Note that not all shapes are
# amenable to performance improvements. When batch sizes are small and
# little time is spent in compute, sparse kernels may be slower than their
# dense counterparts: at ``bs=4``, the uncompiled sparse model runs at
# 0.71x the speed of the dense baseline.
#
# Because semi-structured sparsity is implemented as a tensor subclass, it
# is compatible with ``torch.compile``. Combining ``torch.compile`` with
# ``to_sparse_semi_structured`` gives a total speedup of about 2x on BERT
# over the uncompiled fp16 baseline (19.35 ms vs. 9.68 ms at ``bs=16``).
#
# .. table::
#
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Metrics            | fp16   | 2:4 sparse   | delta / speedup | compiled  |
#    +====================+========+==============+=================+===========+
#    | Exact Match (%)    | 78.53  | 78.44        | -0.09           |           |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | F1 (%)             | 86.93  | 86.49        | -0.44           |           |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=4)    | 11.10  | 15.54        | 0.71x           | no        |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=16)   | 19.35  | 15.74        | 1.23x           | no        |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=64)   | 72.71  | 59.41        | 1.22x           | no        |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=256)  | 286.65 | 247.63       | 1.16x           | no        |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=4)    | 7.59   | 7.46         | 1.02x           | yes       |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=16)   | 11.47  | 9.68         | 1.18x           | yes       |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=64)   | 41.57  | 36.92        | 1.13x           | yes       |
#    +--------------------+--------+--------------+-----------------+-----------+
#    | Time (ms, bs=256)  | 159.22 | 142.23       | 1.12x           | yes       |
#    +--------------------+--------+--------------+-----------------+-----------+
#
# Conclusion
# ----------
#
# In this tutorial, we have shown how to prune BERT to be 2:4 sparse and
# how to accelerate a 2:4 sparse model for inference. By taking advantage
# of our ``SparseSemiStructuredTensor`` subclass, we were able to achieve a
# 1.3x speedup over the fp16 baseline, and up to 2x with
# ``torch.compile``. We also demonstrated the benefits of 2:4 sparsity by
# fine-tuning BERT to recover nearly all of the lost F1 (86.92 dense vs 86.48 sparse).
#