# -*- coding: utf-8 -*-
"""
(beta) Accelerating BERT with semi-structured (2:4) sparsity
=============================================================
**Author**: `Jesse Cai <https://github.com/jcaip>`_

"""

####################################################################
# Overview
# --------
#
# Like other forms of sparsity, **semi-structured sparsity** is a model
# optimization technique that seeks to reduce the memory overhead and
# latency of a neural network at the expense of some model accuracy. It is
# also known as **fine-grained structured sparsity** or **2:4 structured
# sparsity**.
#
# Semi-structured sparsity derives its name from its unique sparsity
# pattern, where n out of every 2n elements are pruned. We most often see
# n=2, hence 2:4 sparsity. Semi-structured sparsity is particularly
# interesting because it can be efficiently accelerated on GPUs and
# doesn't degrade model accuracy as much as other sparsity patterns.
#
# With the introduction of
# `semi-structured sparsity support <https://pytorch.org/docs/2.1/sparse.html#sparse-semi-structured-tensors>`_,
# it is possible to prune and accelerate a semi-structured sparse model
# without leaving PyTorch. We will explain this process in this tutorial.
#
# .. image:: ../../_static/img/pruning_flow.jpg
#
# By the end of this tutorial, we will have sparsified a BERT
# question-answering model to be 2:4 sparse and fine-tuned it to recover
# nearly all of the lost F1 (86.92 dense vs 86.48 sparse). Finally, we
# will accelerate this 2:4 sparse model for inference, yielding a 1.3x
# speedup.
#

#####################################################
# Requirements
# ------------
#
# - PyTorch >= 2.1.
# - An NVIDIA GPU with semi-structured sparsity support (Compute
#   Capability 8.0+).
#
# This tutorial is designed for beginners to semi-structured sparsity and
# sparsity in general. For users with existing 2:4 sparse models,
# accelerating ``nn.Linear`` layers for inference with
# ``to_sparse_semi_structured`` is quite straightforward. Here is an example:
#

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                     globals={"linear": linear,
                              "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

######################################################################
# What problem does semi-structured sparsity solve?
# -------------------------------------------------
#
# The general motivation behind sparsity is simple: if there are zeros in
# your network, you can optimize efficiency by not storing or computing
# those parameters. However, the specifics of sparsity are tricky. Zeroing
# out parameters doesn't affect the latency / memory overhead of our model
# out of the box.
#
# This is because the dense tensor still contains the pruned (zero)
# elements, and the dense matrix multiplication kernel will still operate
# on those elements. In order to realize performance gains, we need to
# swap out dense kernels for sparse kernels, which skip calculations
# involving pruned elements.
#
# To do this, sparse kernels operate on sparse matrices, which do not
# store the pruned elements and keep the specified elements in a
# compressed format.
#
# For semi-structured sparsity, we store exactly half of the original
# parameters along with some compressed metadata about how the elements
# were arranged.
#
# .. image:: https://developer-blogs.nvidia.com/wp-content/uploads/2023/06/2-4-structured-sparsity-pattern.png
#    :align: center
#    :width: 80%
#
# Image sourced from `NVIDIA blog post <https://developer.nvidia.com/blog/structured-sparsity-in-the-nvidia-ampere-architecture-and-applications-in-search-engines/>`_ on semi-structured sparsity.
#
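
######################################################################
# As a rough, back-of-the-envelope illustration of the storage savings:
# the figures below assume 2 bits of position metadata per kept element,
# which is the standard 2:4 encoding described by NVIDIA; the exact
# on-device layout depends on the backend, so treat this as illustrative
# only.
#

rows, cols = 3072, 10240                      # shape of the weight used above
dense_bytes = rows * cols * 2                 # fp16: 2 bytes per element
values_bytes = rows * (cols // 2) * 2         # only half of the elements are stored
meta_bytes = rows * (cols // 2) * 2 // 8      # ~2 bits of metadata per kept element
print(f"dense: {dense_bytes / 2**20:.1f} MiB | "
      f"2:4 sparse: {(values_bytes + meta_bytes) / 2**20:.1f} MiB "
      f"(~{(values_bytes + meta_bytes) / dense_bytes:.0%} of dense)")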

######################################################################
# There are many different sparse layouts, each with their own benefits
# and drawbacks. The 2:4 semi-structured sparse layout is particularly
# interesting for two reasons:
#
# * Unlike previous sparse formats,
#   semi-structured sparsity was designed to be efficiently accelerated on
#   GPUs. In 2020, NVIDIA introduced hardware support for semi-structured
#   sparsity with their Ampere architecture, and have also released fast
#   sparse kernels via
#   CUTLASS and `cuSPARSELt <https://docs.nvidia.com/cuda/cusparselt/index.html>`__.
#
# * At the same time, semi-structured sparsity tends to have a milder
#   impact on model accuracy compared to other sparse formats, especially
#   when accounting for more advanced pruning / fine-tuning methods. NVIDIA
#   has shown in their `white paper <https://arxiv.org/abs/2104.08378>`_
#   that a simple paradigm of magnitude pruning once to be 2:4 sparse and
#   then retraining the model yields nearly identical model accuracies.
#
# Semi-structured sparsity exists in a sweet spot, providing a 2x
# (theoretical) speedup at a much lower sparsity level (50%), while still
# being granular enough to preserve model accuracy.
#
# +---------------------+-------------+--------+------------+-------------+
# | Network             | Data Set    | Metric | Dense FP16 | Sparse FP16 |
# +=====================+=============+========+============+=============+
# | ResNet-50           | ImageNet    | Top-1  | 76.1       | 76.2        |
# +---------------------+-------------+--------+------------+-------------+
# | ResNeXt-101_32x8d   | ImageNet    | Top-1  | 79.3       | 79.3        |
# +---------------------+-------------+--------+------------+-------------+
# | Xception            | ImageNet    | Top-1  | 79.2       | 79.2        |
# +---------------------+-------------+--------+------------+-------------+
# | SSD-RN50            | COCO2017    | bbAP   | 24.8       | 24.8        |
# +---------------------+-------------+--------+------------+-------------+
# | MaskRCNN-RN50       | COCO2017    | bbAP   | 37.9       | 37.9        |
# +---------------------+-------------+--------+------------+-------------+
# | FairSeq Transformer | EN-DE WMT14 | BLEU   | 28.2       | 28.5        |
# +---------------------+-------------+--------+------------+-------------+
# | BERT-Large          | SQuAD v1.1  | F1     | 91.9       | 91.9        |
# +---------------------+-------------+--------+------------+-------------+
#
# Semi-structured sparsity has an additional advantage from a workflow
# perspective. Because the sparsity level is fixed at 50%, it is easier to
# decompose the problem of sparsifying a model into two distinct
# subproblems:
#
# - Accuracy - How can we find a set of 2:4 sparse weights that minimize
#   the accuracy degradation of our model?
#
# - Performance - How can we accelerate our 2:4 sparse weights for
#   inference and reduced memory overhead?
#

#####################################################################
# An example of a zeroed-out dense tensor in the 2:4 format:
#
# .. math::
#
#    \begin{bmatrix}
#       1 & 1 & 0 & 0 \\
#       0 & 0 & 1 & 1 \\
#       1 & 0 & 0 & 0 \\
#       0 & 0 & 1 & 1 \\
#    \end{bmatrix}
#
# The natural handoff point between these two problems is a zeroed-out
# dense tensor like the one above. Our inference solution is designed to
# compress and accelerate tensors in this format. We anticipate many users
# coming up with custom masking solutions, as this is an active area of
# research.
#
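
######################################################################
# For instance, a toy magnitude-based masking function might look like the
# sketch below. ``naive_2_4_mask`` is a hypothetical helper written for
# this tutorial, not a PyTorch API: it keeps the two largest-magnitude
# elements in every group of four along the last dimension, which is
# similar in spirit to the weight-norm sparsifier we use later.
#

def naive_2_4_mask(weight: torch.Tensor) -> torch.Tensor:
    out_features, in_features = weight.shape
    # view the weight as groups of 4 and find the 2 largest-magnitude entries per group
    groups = weight.detach().abs().reshape(out_features, in_features // 4, 4)
    topk_idx = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups).scatter_(-1, topk_idx, 1.0).bool()
    return mask.reshape(out_features, in_features)

example_weight = torch.randn(128, 128)
print("fraction of elements kept by the 2:4 mask:",
      naive_2_4_mask(example_weight).float().mean().item())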

######################################################################
# Now that we've learned a little more about semi-structured sparsity,
# let's apply it to a BERT model trained on a question answering task,
# SQuAD.
#
# Intro & Setup
# -------------
#
# Let's start by importing all the packages we need.
#

# If you are running this in Google Colab, run:
#
# .. code-block:: python
#
#    !pip install datasets transformers evaluate accelerate pandas
#
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)

# Set default device to "cuda:0"
torch.set_default_device(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

######################################################################
# We'll also need to define some helper functions that are specific to the
# dataset / task at hand. These were adapted from
# `this Hugging Face course <https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt>`__.
#

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


######################################################################
# Now that those are defined, we just need one additional helper function,
# which will help us benchmark our model.
#

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[batch_size] = p50

            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_sec

######################################################################
# We will get started by loading our model and tokenizer, and then setting
# up our dataset.
#

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)


######################################################################
# Establishing a baseline
# =======================
#
# Next, we'll train a quick baseline of our model on SQuAD. This task asks
# our model to identify spans, or segments of text, in a given context
# (Wikipedia articles) that answer a given question. Running the following
# code gives us an F1 score of 86.9. This is quite close to the reported
# NVIDIA score, and the difference is likely due to BERT-base
# vs. BERT-large or the fine-tuning hyperparameters.
#

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    report_to=None,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity requires fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")


######################################################################
# Pruning BERT to be 2:4 sparse
# -----------------------------
#
# Now that we have our baseline, it's time to prune BERT. There are many
# different pruning strategies, but one of the most common is **magnitude
# pruning**, which seeks to remove the weights with the lowest L1 norm.
# Magnitude pruning was used by NVIDIA in all their results and is a
# common baseline.
#
# To do this, we will use the ``torch.ao.pruning`` package, which contains
# a weight-norm (magnitude) sparsifier. These sparsifiers work by applying
# mask parametrizations to the weight tensors in a model. This lets them
# simulate sparsity by masking out the pruned weights.
#
# We'll also have to decide what layers of the model to apply sparsity to,
# which in this case is all of the ``nn.Linear`` layers, except for the
# task-specific head outputs. That's because semi-structured sparsity has
# `shape constraints <https://pytorch.org/docs/2.1/sparse.html#constructing-sparse-semi-structured-tensors>`_,
# and the task-specific ``nn.Linear`` layers do not satisfy them.
#

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if ``nn.Linear`` and part of the BERT encoder layers
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]
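
######################################################################
# As an optional sanity check (this snippet is illustrative and not part
# of the original recipe), we can list the ``nn.Linear`` modules that the
# config above leaves dense -- we expect to see only the task-specific
# output head(s), such as ``qa_outputs``:
#

covered = {cfg["tensor_fqn"].rsplit(".", 1)[0] for cfg in sparse_config}
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and fqn not in covered:
        print("left dense:", fqn, tuple(module.weight.shape))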

######################################################################
# The first step for pruning the model is to insert parametrizations for
# masking the weights of the model. This is done by the prepare step.
# Anytime we try to access the ``.weight`` we will get ``mask * weight``
# instead.
#

# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)


######################################################################
# Then, we'll take a single pruning step. All pruners implement an
# ``update_mask()`` method that updates the mask with the logic determined
# by the pruner implementation. The step method calls ``update_mask()``
# for the weights specified in the sparse config.
#
# We will also evaluate the model to show the accuracy degradation of
# zero-shot pruning, or pruning without fine-tuning / retraining.
#

sparsifier.step()
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        pruned = compute_metrics(
            *predictions.predictions,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
print("pruned eval metrics:", pruned)


######################################################################
# In this state, we can start fine-tuning the model, updating the elements
# that were not pruned to better account for the accuracy loss. Once
# we've reached a satisfactory state, we can call ``squash_mask`` to fuse
# the mask and the weight together. This will remove the parametrizations
# and we are left with a zeroed-out 2:4 dense model.
#

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])

df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")
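
######################################################################
# At this point roughly half of the elements in each pruned weight should
# be exactly zero. A quick optional check on one of the pruned layers:
#

w = model.bert.encoder.layer[0].intermediate.dense.weight
print(f"fraction of zeros: {(w == 0).float().mean().item():.2f}")  # expect ~0.50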
# steps", ylabel="loss")570571572######################################################################573# Accelerating 2:4 sparse models for inference574# --------------------------------------------575#576# Now that we have a model in this format, we can accelerate it for577# inference just like in the QuickStart Guide.578#579580model = model.cuda().half()581# accelerate for sparsity582for fqn, module in model.named_modules():583if isinstance(module, nn.Linear) and "layer" in fqn:584module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))585586with torch.no_grad():587predictions = trainer.predict(tokenized_squad_dataset["validation"])588start_logits, end_logits = predictions.predictions589metrics_sparse = compute_metrics(590start_logits,591end_logits,592tokenized_squad_dataset["validation"],593squad_dataset["validation"],594)595print("sparse eval metrics: ", metrics_sparse)596sparse_perf = measure_execution_time(597model,598batch_sizes,599tokenized_squad_dataset["validation"],600)601print("sparse perf metrics: ", sparse_perf)602603604######################################################################605# Retraining our model after magnitude pruning has recovered nearly all of606# the F1 that has been lost when the model was pruned. At the same time we607# have achieved a 1.28x speedup for ``bs=16``. Note that not all shapes are608# amenable to performance improvements. When batch sizes are small and609# limited time is spent in compute sparse kernels may be slower than their610# dense counterparts.611#612# Because semi-structured sparsity is implemented as a tensor subclass, it613# is compatible with ``torch.compile``. When composed with614# ``to_sparse_semi_structured``, we are able to achieve a total 2x speedup615# on BERT.616#617# .. table::618#619# +--------------------+--------+--------------+-----------------+-----------+620# | Metrics | fp16 | 2:4 sparse | delta / speedup | compiled |621# +====================+========+==============+=================+===========+622# | Exact Match (%) | 78.53 | 78.44 | -0.09 | |623# +--------------------+--------+--------------+-----------------+-----------+624# | F1 (%) | 86.93 | 86.49 | -0.44 | |625# +--------------------+--------+--------------+-----------------+-----------+626# | Time (bs=4) | 11.10 | 15.54 | 0.71x | no |627# +--------------------+--------+--------------+-----------------+-----------+628# | Time (bs=16) | 19.35 | 15.74 | 1.23x | no |629# +--------------------+--------+--------------+-----------------+-----------+630# | Time (bs=64) | 72.71 | 59.41 | 1.22x | no |631# +--------------------+--------+--------------+-----------------+-----------+632# | Time (bs=256) | 286.65 | 247.63 | 1.14x | no |633# +--------------------+--------+--------------+-----------------+-----------+634# | Time (bs=4) | 7.59 | 7.46 | 1.02x | yes |635# +--------------------+--------+--------------+-----------------+-----------+636# | Time (bs=16) | 11.47 | 9.68 | 1.18x | yes |637# +--------------------+--------+--------------+-----------------+-----------+638# | Time (bs=64) | 41.57 | 36.92 | 1.13x | yes |639# +--------------------+--------+--------------+-----------------+-----------+640# | Time (bs=256) | 159.22 | 142.23 | 1.12x | yes |641# +--------------------+--------+--------------+-----------------+-----------+642#643# Conclusion644# ==========645#646# In this tutorial, we have shown how to prune BERT to be 2:4 sparse and647# how to accelerate a 2:4 sparse model for inference. 

######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have shown how to prune BERT to be 2:4 sparse and
# how to accelerate a 2:4 sparse model for inference. By taking advantage
# of our ``SparseSemiStructuredTensor`` subclass, we were able to achieve
# a 1.3x speedup over the fp16 baseline, and up to 2x with
# ``torch.compile``. We also demonstrated that 2:4 sparsity comes at
# little accuracy cost by fine-tuning BERT to recover nearly all of the
# lost F1 (86.92 dense vs 86.48 sparse).
#