"""1Semi-Supervised Learning using USB built upon PyTorch2=====================================================34**Author**: `Hao Chen <https://github.com/Hhhhhhao>`_56Unified Semi-supervised learning Benchmark (USB) is a semi-supervised7learning (SSL) framework built upon PyTorch.8Based on Datasets and Modules provided by PyTorch, USB becomes a flexible,9modular, and easy-to-use framework for semi-supervised learning.10It supports a variety of semi-supervised learning algorithms, including11``FixMatch``, ``FreeMatch``, ``DeFixMatch``, ``SoftMatch``, and so on.12It also supports a variety of imbalanced semi-supervised learning algorithms.13The benchmark results across different datasets of computer vision, natural14language processing, and speech processing are included in USB.1516This tutorial will walk you through the basics of using the USB lighting17package.18Let's get started by training a ``FreeMatch``/``SoftMatch`` model on19CIFAR-10 using pretrained Vision Transformers (ViT)!20And we will show it is easy to change the semi-supervised algorithm and train21on imbalanced datasets.222324.. figure:: /_static/img/usb_semisup_learn/code.png25:alt: USB framework illustration26"""272829######################################################################30# Introduction to ``FreeMatch`` and ``SoftMatch`` in Semi-Supervised Learning31# ---------------------------------------------------------------------------32#33# Here we provide a brief introduction to ``FreeMatch`` and ``SoftMatch``.34# First, we introduce a famous baseline for semi-supervised learning called ``FixMatch``.35# ``FixMatch`` is a very simple framework for semi-supervised learning, where it36# utilizes a strong augmentation to generate pseudo labels for unlabeled data.37# It adopts a confidence thresholding strategy to filter out the low-confidence38# pseudo labels with a fixed threshold set.39# ``FreeMatch`` and ``SoftMatch`` are two algorithms that improve upon ``FixMatch``.40# ``FreeMatch`` proposes adaptive thresholding strategy to replace the fixed41# thresholding strategy in ``FixMatch``. The adaptive thresholding progressively42# increases the threshold according to the learning status of the model on each43# class. ``SoftMatch`` absorbs the idea of confidence thresholding as an44# weighting mechanism. It proposes a Gaussian weighting mechanism to overcome45# the quantity-quality trade-off in pseudo-labels. 
# In this tutorial, we will use USB to train ``FreeMatch`` and ``SoftMatch``.


######################################################################
# Use USB to Train ``FreeMatch``/``SoftMatch`` on CIFAR-10 with only 40 labels
# ----------------------------------------------------------------------------
#
# USB is easy to use and extend, affordable to small groups, and comprehensive
# for developing and evaluating SSL algorithms.
# USB provides the implementation of 14 SSL algorithms based on consistency
# regularization, and 15 evaluation tasks spanning the CV, NLP, and audio
# domains.
# It has a modular design that allows users to easily extend the package by
# adding new algorithms and tasks.
# It also supports a Python API for easier adaptation to different SSL
# algorithms on new data.
#
#
# Now, let's use USB to train ``FreeMatch`` and ``SoftMatch`` on CIFAR-10.
# First, we need to install the USB package ``semilearn`` and import the
# necessary API functions from it.
# If you are running this in Google Colab, install ``semilearn`` by running:
# ``!pip install semilearn``.
#
# Below is a list of functions we will use from ``semilearn``:
#
# - ``get_dataset`` to load the dataset; here we use CIFAR-10
# - ``get_data_loader`` to create the train (labeled and unlabeled) and test
#   data loaders; the unlabeled train loader provides both strong and weak
#   augmentations of the unlabeled data
# - ``get_net_builder`` to create a model; here we use a pretrained ViT
# - ``get_algorithm`` to create the semi-supervised learning algorithm,
#   here ``FreeMatch`` or ``SoftMatch``
# - ``get_config`` to get the default configuration of the algorithm
# - ``Trainer``, a class for training and evaluating the algorithm on the
#   dataset
#
# Note that a CUDA-enabled backend is required for training with the
# ``semilearn`` package.
# See `Enabling CUDA in Google Colab <https://pytorch.org/tutorials/beginner/colab#enabling-cuda>`__
# for instructions on enabling CUDA in Google Colab.
#
import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer

######################################################################
# After importing the necessary functions, we first set the hyperparameters of
# the algorithm.
#
config = {
    'algorithm': 'freematch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)
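

######################################################################
# ``get_config`` fills our dictionary in with the algorithm's default
# configuration and returns a namespace-style object, so every option is
# available as an attribute, including fields we never set ourselves (for
# example, ``include_lb_to_ulb``, which we pass to ``get_dataset`` below).
# A quick, purely illustrative sanity check:
#
print(config.algorithm, config.dataset, config.num_labels)
print(config.include_lb_to_ulb)  # supplied by the default configuration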


######################################################################
# Then, we load the dataset and create the data loaders for training and
# testing, and we specify the model and algorithm to use.
# Note that the unlabeled train loader uses a batch size of ``config.uratio``
# (here 2) times the labeled batch size, so each training step draws 16
# labeled and 32 unlabeled examples.
#
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels,
                           config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)


######################################################################
# We can now start training the algorithm on CIFAR-10 with 40 labels.
# We train for 500 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)


######################################################################
# Finally, let's evaluate the trained model on the validation set.
# After training for 500 iterations with ``FreeMatch`` on only 40 labels of
# CIFAR-10, we obtain a classifier that achieves around 87% accuracy on the
# validation set.
#
trainer.evaluate(eval_loader)


######################################################################
# Use USB to Train ``SoftMatch`` with a Specific Imbalanced Algorithm on Imbalanced CIFAR-10
# ------------------------------------------------------------------------------------------
#
# Now suppose we have an imbalanced labeled set and an imbalanced unlabeled
# set of CIFAR-10, and we want to train a ``SoftMatch`` model on them.
# We create the imbalanced labeled and unlabeled sets by setting
# ``lb_imb_ratio`` and ``ulb_imb_ratio`` to 10.
# Also, we replace the ``algorithm`` with ``softmatch`` and set ``imbalanced``
# to ``True``.
#
config = {
    'algorithm': 'softmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,

    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 1500,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,
    'lb_imb_ratio': 10,
    'ulb_imb_ratio': 10,
    'ulb_num_labels': 3000,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    'num_workers': 4,
}
config = get_config(config)

######################################################################
# Then, we reload the dataset and create the data loaders for training and
# testing, and we specify the model and algorithm to use.
#
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels,
                           config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config, get_net_builder(config.net, from_name=False), tb_log=None, logger=None)
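

######################################################################
# Before training, it can be useful to confirm the long-tailed label
# distribution we just configured. Here is a small sketch; it assumes the
# dataset objects expose a ``targets`` attribute (as torchvision-style
# datasets do), which is not guaranteed by the USB API::
#
#     from collections import Counter
#
#     # per-class counts in the imbalanced labeled and unlabeled sets
#     print(Counter(dataset_dict['train_lb'].targets))
#     print(Counter(dataset_dict['train_ulb'].targets))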


######################################################################
# We can now start training ``SoftMatch`` on the imbalanced CIFAR-10 dataset.
# We train for 500 iterations and evaluate every 500 iterations.
#
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)


######################################################################
# Finally, let's evaluate the trained model on the validation set.
#
trainer.evaluate(eval_loader)


######################################################################
# References:
#
# - [1] USB: https://github.com/microsoft/Semi-supervised-learning
# - [2] Kihyuk Sohn et al. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
# - [3] Yidong Wang et al. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning
# - [4] Hao Chen et al. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning