# GitHub repository: pytorch/tutorials
# Path: blob/main/advanced_source/usb_semisup_learn.py
"""
Semi-Supervised Learning using USB built upon PyTorch
=====================================================

**Author**: `Hao Chen <https://github.com/Hhhhhhao>`_

Unified Semi-supervised learning Benchmark (USB) is a semi-supervised
learning (SSL) framework built upon PyTorch.
Building on the Datasets and Modules provided by PyTorch, USB is a flexible,
modular, and easy-to-use framework for semi-supervised learning.
It supports a variety of semi-supervised learning algorithms, including
``FixMatch``, ``FreeMatch``, ``DeFixMatch``, and ``SoftMatch``, among others.
It also supports a variety of imbalanced semi-supervised learning algorithms.
USB includes benchmark results across datasets from computer vision, natural
language processing, and speech processing.

This tutorial will walk you through the basics of using the USB ``semilearn``
package.
Let's get started by training a ``FreeMatch``/``SoftMatch`` model on
CIFAR-10 using pretrained Vision Transformers (ViT)!
We will also show how easy it is to change the semi-supervised algorithm and
train on imbalanced datasets.


.. figure:: /_static/img/usb_semisup_learn/code.png
   :alt: USB framework illustration
"""


######################################################################
# Introduction to ``FreeMatch`` and ``SoftMatch`` in Semi-Supervised Learning
# ---------------------------------------------------------------------------
#
# Here we provide a brief introduction to ``FreeMatch`` and ``SoftMatch``.
# First, we introduce ``FixMatch``, a well-known baseline for semi-supervised
# learning. ``FixMatch`` is a very simple framework: it generates pseudo labels
# from weakly augmented views of the unlabeled data and trains the model to
# predict them on strongly augmented views of the same data.
# It adopts a confidence thresholding strategy that filters out low-confidence
# pseudo labels with a fixed threshold.
# ``FreeMatch`` and ``SoftMatch`` are two algorithms that improve upon ``FixMatch``.
# ``FreeMatch`` proposes an adaptive thresholding strategy to replace the fixed
# thresholding strategy in ``FixMatch``. The adaptive thresholding progressively
# increases the threshold according to the learning status of the model on each
# class. ``SoftMatch`` turns the idea of confidence thresholding into a
# weighting mechanism. It proposes a Gaussian weighting function to overcome
# the quantity-quality trade-off in pseudo labels. In this tutorial, we will
# use USB to train ``FreeMatch`` and ``SoftMatch``.


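######################################################################
# To make the three pseudo-label filters described above more concrete, the
# sketch below computes a per-sample weight from a list of class probabilities
# in plain Python. It is a simplified illustration, not the ``semilearn``
# implementation: the helper names (``fixmatch_mask``, ``freematch_thresholds``,
# ``softmatch_weight``), the 0.95 threshold, and the example statistics are
# assumptions chosen for this example, and the running estimation of the
# statistics during training is left out.
#
import math


def fixmatch_mask(probs, threshold=0.95):
    """Hard mask: 1.0 if the top class probability reaches the threshold, else 0.0."""
    return 1.0 if max(probs) >= threshold else 0.0


def freematch_thresholds(global_conf, class_probs):
    """Per-class thresholds: a global confidence estimate scaled by each class's
    estimate relative to the best-learned class (assumed simplification)."""
    top = max(class_probs)
    return [global_conf * p / top for p in class_probs]


def softmatch_weight(probs, mean_conf, var_conf):
    """Soft weight: 1.0 at or above the mean confidence, Gaussian fall-off below it."""
    conf = max(probs)
    if conf >= mean_conf:
        return 1.0
    return math.exp(-((conf - mean_conf) ** 2) / (2.0 * var_conf))


example_probs = [0.80, 0.15, 0.05]  # a moderately confident prediction
# The fixed threshold discards this sample entirely.
print("FixMatch mask:", fixmatch_mask(example_probs))
# Classes that are learned less well get a lower threshold.
print("FreeMatch thresholds:", freematch_thresholds(0.9, [0.5, 0.3, 0.2]))
# The Gaussian weight keeps the sample with a reduced weight instead.
print("SoftMatch weight:", softmatch_weight(example_probs, mean_conf=0.9, var_conf=0.01))

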
######################################################################
# Use USB to Train ``FreeMatch``/``SoftMatch`` on CIFAR-10 with only 40 labels
# ----------------------------------------------------------------------------
#
# USB is easy to use and extend, affordable to small groups, and comprehensive
# for developing and evaluating SSL algorithms.
# USB provides implementations of 14 SSL algorithms based on consistency
# regularization, and 15 evaluation tasks from the CV, NLP, and audio domains.
# It has a modular design that allows users to easily extend the package by
# adding new algorithms and tasks.
# It also provides a Python API that makes it easier to adapt different SSL
# algorithms to new data.
#
#
# Now, let's use USB to train ``FreeMatch`` and ``SoftMatch`` on CIFAR-10.
# First, we need to install the USB package ``semilearn`` and import the
# necessary API functions from USB.
# If you are running this in Google Colab, install ``semilearn`` by running:
# ``!pip install semilearn``.
#
# Below is a list of the functions and classes we will use from ``semilearn``:
#
# - ``get_dataset`` to load the dataset, here CIFAR-10
# - ``get_data_loader`` to create the train (labeled and unlabeled) and test
#   data loaders; the unlabeled train loader provides both strong and weak
#   augmentations of the unlabeled data
# - ``get_net_builder`` to create a model, here a pretrained ViT
# - ``get_algorithm`` to create the semi-supervised learning algorithm,
#   here ``FreeMatch`` and ``SoftMatch``
# - ``get_config`` to get the default configuration of the algorithm
# - ``Trainer``, a class for training and evaluating the
#   algorithm on a dataset
#
# Note that a CUDA-enabled backend is required for training with the ``semilearn`` package.
# See `Enabling CUDA in Google Colab <https://pytorch.org/tutorials/beginner/colab#enabling-cuda>`__ for instructions
# on enabling CUDA in Google Colab.
#
import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer

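######################################################################
# Because training needs a CUDA device, it can help to check for one before
# building anything. The helper below is a small optional sketch and not part
# of ``semilearn``; the name ``cuda_is_available`` is made up for this example,
# and it returns ``False`` instead of raising when PyTorch cannot be imported.
#
def cuda_is_available():
    """Return True only if PyTorch imports and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if not cuda_is_available():
    print("No CUDA device found; the training steps in this tutorial may fail.")
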
######################################################################
# After importing the necessary functions, we first set the hyperparameters of
# the algorithm.
#
config = {
    'algorithm': 'freematch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,  # train for 500 iterations
    'num_eval_iter': 500,   # evaluate every 500 iterations
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,  # the unlabeled batch is uratio times the labeled batch
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)


######################################################################
# Then, we load the dataset and create the data loaders for training and
# testing. We also specify the model and the algorithm to use.
#
dataset_dict = get_dataset(
    config,
    config.algorithm,
    config.dataset,
    config.num_labels,
    config.num_classes,
    data_dir=config.data_dir,
    include_lb_to_ulb=config.include_lb_to_ulb,
)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
# The unlabeled batch is ``uratio`` times larger than the labeled batch.
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)


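######################################################################
# The batch sizes passed to the loaders above follow from ``batch_size`` and
# ``uratio``. The sketch below redoes that arithmetic with plain numbers
# copied from the configuration. The variable names are made up for this
# example, and the assumption that every unlabeled image contributes one weak
# and one strong view to each step comes from the loader description earlier
# in this tutorial, not from inspecting ``semilearn`` internals.
#
lb_batch_size = 16                 # 'batch_size'
unlabeled_ratio = 2                # 'uratio'
ulb_batch_size = int(lb_batch_size * unlabeled_ratio)
views_per_unlabeled_image = 2      # weak + strong view (assumed)
images_per_step = lb_batch_size + views_per_unlabeled_image * ulb_batch_size
print(f"labeled batch: {lb_batch_size}, unlabeled batch: {ulb_batch_size}")
print(f"augmented images per training step: {images_per_step}")

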
######################################################################
# We can now start training the algorithm on CIFAR-10 with 40 labels.
# We train for 500 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)


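######################################################################
# To put the 500-iteration budget in perspective, the short calculation below
# derives how often the run logs and evaluates, and how many times the 40
# labeled images are revisited. It only uses values from the configuration
# shown earlier. The variable names are made up for this example, and it
# assumes the trainer logs and evaluates exactly every ``num_log_iter`` and
# ``num_eval_iter`` iterations.
#
budget_iters, log_every, eval_every = 500, 50, 500
labeled_batch, labeled_total = 16, 40
log_events = budget_iters // log_every
eval_events = budget_iters // eval_every
labeled_passes = budget_iters * labeled_batch / labeled_total
print(f"log events: {log_events}, evaluations: {eval_events}")
print(f"passes over the labeled set: {labeled_passes:.0f}")

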
######################################################################
# Finally, let's evaluate the trained model on the validation set.
# After training for 500 iterations with ``FreeMatch`` on only 40 labels of
# CIFAR-10, we obtain a classifier that achieves around 87% accuracy on the
# validation set.
#
trainer.evaluate(eval_loader)



######################################################################
# Use USB to Train ``SoftMatch`` with specific imbalanced algorithm on imbalanced CIFAR-10
# ----------------------------------------------------------------------------------------
#
# Now suppose we have imbalanced labeled and unlabeled sets of CIFAR-10,
# and we want to train a ``SoftMatch`` model on them.
# We create the imbalanced labeled and unlabeled sets by setting both
# ``lb_imb_ratio`` and ``ulb_imb_ratio`` to 10.
# We also replace the ``algorithm`` with ``softmatch``. The configuration
# below describes the imbalance through the two ratios, ``num_labels``, and
# ``ulb_num_labels``; it does not set a separate ``imbalanced`` key.
#
config = {
    'algorithm': 'softmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,  # train for 500 iterations
    'num_eval_iter': 500,   # evaluate every 500 iterations
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 1500,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,
    'lb_imb_ratio': 10,   # imbalance ratio of the labeled set
    'ulb_imb_ratio': 10,  # imbalance ratio of the unlabeled set
    'ulb_num_labels': 3000,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,  # the unlabeled batch is uratio times the labeled batch
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)

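######################################################################
# An imbalance ratio is commonly read as the size of the largest class
# divided by the size of the smallest one, with class sizes decaying
# exponentially in between. The sketch below builds such a long-tailed
# profile for 10 classes. It illustrates the general idea only: the helper
# ``long_tailed_counts`` is made up for this example, and both the exact
# formula used by ``semilearn`` and the reading of ``num_labels`` and
# ``ulb_num_labels`` as the size of the largest class are assumptions.
#
def long_tailed_counts(max_count, num_classes, imb_ratio):
    """Class sizes decaying exponentially from max_count to max_count / imb_ratio."""
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2")
    counts = []
    for class_idx in range(num_classes):
        decay = imb_ratio ** (-class_idx / (num_classes - 1))
        counts.append(round(max_count * decay))
    return counts


labeled_counts = long_tailed_counts(max_count=1500, num_classes=10, imb_ratio=10)
unlabeled_counts = long_tailed_counts(max_count=3000, num_classes=10, imb_ratio=10)
print("labeled images per class:  ", labeled_counts)
print("unlabeled images per class:", unlabeled_counts)
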
######################################################################
# Then, we reload the dataset and create the data loaders for training and
# testing. We also specify the model and the algorithm to use.
#
dataset_dict = get_dataset(
    config,
    config.algorithm,
    config.dataset,
    config.num_labels,
    config.num_classes,
    data_dir=config.data_dir,
    include_lb_to_ulb=config.include_lb_to_ulb,
)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
# The unlabeled batch is ``uratio`` times larger than the labeled batch.
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)


######################################################################
# We can now start training ``SoftMatch`` on the imbalanced CIFAR-10 data
# configured above.
# We train for 500 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)


######################################################################
# Finally, let's evaluate the trained model on the validation set.
#
trainer.evaluate(eval_loader)



######################################################################
# References:
#
# - [1] USB: https://github.com/microsoft/Semi-supervised-learning
# - [2] Kihyuk Sohn et al. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
# - [3] Yidong Wang et al. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning
# - [4] Hao Chen et al. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning