Path: blob/main/advanced_source/cpp_cuda_graphs/mnist.cpp
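
// Trains the classic MNIST convnet with LibTorch, optionally capturing the
// training step and/or the evaluation step into CUDA graphs and replaying
// them instead of relaunching kernels eagerly (see the --use-train-graph
// and --use-test-graph flags parsed in main()).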
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

// Where to find the MNIST dataset.
const char* kDataRoot = "./data";

// The batch size for training.
const int64_t kTrainBatchSize = 64;

// The batch size for testing.
const int64_t kTestBatchSize = 1000;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 10;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

// Model that we will be training.
struct Net : torch::nn::Module {
  Net()
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

// Record an event on `dependency` and make `dependent` wait on it, so work
// later enqueued on `dependent` runs after everything already on `dependency`.
void stream_sync(
    at::cuda::CUDAStream& dependency,
    at::cuda::CUDAStream& dependent) {
  at::cuda::CUDAEvent cuda_ev;
  cuda_ev.record(dependency);
  cuda_ev.block(dependent);
}

// One full optimizer step: forward pass, NLL loss, backward pass, update.
// It reads from and writes into the caller's tensors, so the same code can
// be captured into a CUDA graph and replayed against static buffers.
void training_step(
    Net& model,
    torch::optim::Optimizer& optimizer,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss) {
  optimizer.zero_grad();
  output = model.forward(data);
  loss = torch::nll_loss(output, targets);
  loss.backward();
  optimizer.step();
}
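
// Capture follows the standard CUDA graphs recipe: run a few warmup
// iterations on a side stream so one-time work (cuDNN algorithm selection,
// allocator pool growth, autograd setup) happens outside capture, then
// record a single step on a non-default capture stream. Replays reuse the
// recorded kernels and the same tensor addresses.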
void capture_train_graph(
    Net& model,
    torch::optim::Optimizer& optimizer,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    at::cuda::CUDAGraph& graph,
    const int num_warmup_iters = 7) {
  model.train();

  auto warmupStream = at::cuda::getStreamFromPool();
  auto captureStream = at::cuda::getStreamFromPool();
  auto legacyStream = at::cuda::getCurrentCUDAStream();

  at::cuda::setCurrentCUDAStream(warmupStream);
  stream_sync(legacyStream, warmupStream);

  for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) {
    training_step(model, optimizer, data, targets, output, loss);
  }

  stream_sync(warmupStream, captureStream);
  at::cuda::setCurrentCUDAStream(captureStream);

  // Record exactly one training step into the graph.
  graph.capture_begin();
  training_step(model, optimizer, data, targets, output, loss);
  graph.capture_end();

  stream_sync(captureStream, legacyStream);
  at::cuda::setCurrentCUDAStream(legacyStream);
}

template <typename DataLoader>
void train(
    size_t epoch,
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t dataset_size,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    at::cuda::CUDAGraph& graph,
    bool use_graph) {
  model.train();

  size_t batch_idx = 0;

  for (const auto& batch : data_loader) {
    // The graph was captured at a fixed batch size, so skip partial batches.
    if (batch.data.size(0) != kTrainBatchSize ||
        batch.target.size(0) != kTrainBatchSize) {
      continue;
    }

    // Copy the batch into the static tensors the graph was captured with.
    data.copy_(batch.data);
    targets.copy_(batch.target);

    if (use_graph) {
      graph.replay();
    } else {
      training_step(model, optimizer, data, targets, output, loss);
    }

    if (batch_idx++ % kLogInterval == 0) {
      float train_loss = loss.item<float>();
      std::cout << "\rTrain Epoch: " << epoch << " ["
                << batch_idx * batch.data.size(0) << "/" << dataset_size
                << "] Loss: " << train_loss;
    }
  }
}

// One evaluation step: forward pass plus summed NLL loss for the batch.
void test_step(
    Net& model,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss) {
  output = model.forward(data);
  loss = torch::nll_loss(output, targets, {}, torch::Reduction::Sum);
}

void capture_test_graph(
    Net& model,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    torch::Tensor& total_loss,
    torch::Tensor& total_correct,
    at::cuda::CUDAGraph& graph,
    const int num_warmup_iters = 7) {
  torch::NoGradGuard no_grad;
  model.eval();

  auto warmupStream = at::cuda::getStreamFromPool();
  auto captureStream = at::cuda::getStreamFromPool();
  auto legacyStream = at::cuda::getCurrentCUDAStream();

  at::cuda::setCurrentCUDAStream(warmupStream);
  stream_sync(legacyStream, warmupStream);

  for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) {
    test_step(model, data, targets, output, loss);
    total_loss += loss;
    total_correct += output.argmax(1).eq(targets).sum();
  }

  stream_sync(warmupStream, captureStream);
  at::cuda::setCurrentCUDAStream(captureStream);

  // Record exactly one evaluation step into the graph.
  graph.capture_begin();
  test_step(model, data, targets, output, loss);
  graph.capture_end();

  stream_sync(captureStream, legacyStream);
  at::cuda::setCurrentCUDAStream(legacyStream);
}

template <typename DataLoader>
void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    torch::Tensor& total_loss,
    torch::Tensor& total_correct,
    at::cuda::CUDAGraph& graph,
    bool use_graph) {
  torch::NoGradGuard no_grad;

  model.eval();
  loss.zero_();
  total_loss.zero_();
  total_correct.zero_();

  for (const auto& batch : data_loader) {
    if (batch.data.size(0) != kTestBatchSize ||
        batch.target.size(0) != kTestBatchSize) {
      continue;
    }
    data.copy_(batch.data);
    targets.copy_(batch.target);

    if (use_graph) {
      graph.replay();
    } else {
      test_step(model, data, targets, output, loss);
    }
    // On both paths the results land in the static `output` and `loss`
    // tensors, so the accumulation below is identical.
    total_loss += loss;
    total_correct += output.argmax(1).eq(targets).sum();
  }

  float test_loss = total_loss.item<float>() / dataset_size;
  float test_accuracy =
      static_cast<float>(total_correct.item<int64_t>()) / dataset_size;

  std::cout << std::endl
            << "Test set: Average loss: " << test_loss
            << " | Accuracy: " << test_accuracy << std::endl;
}
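
// Entry point. Example invocation (the binary name depends on your build
// setup and is only illustrative; both flags are optional):
//
//   ./mnist_cuda_graphs --use-train-graph --use-test-graph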
int main(int argc, char* argv[]) {
  if (!torch::cuda::is_available()) {
    std::cout << "CUDA is not available!" << std::endl;
    return -1;
  }

  bool use_train_graph = false;
  bool use_test_graph = false;

  std::vector<std::string> arguments(argv + 1, argv + argc);
  for (std::string& arg : arguments) {
    if (arg == "--use-train-graph") {
      std::cout << "Using CUDA Graph for training." << std::endl;
      use_train_graph = true;
    }
    if (arg == "--use-test-graph") {
      std::cout << "Using CUDA Graph for testing." << std::endl;
      use_test_graph = true;
    }
  }

  torch::manual_seed(1);
  torch::cuda::manual_seed(1);
  torch::Device device(torch::kCUDA);

  Net model;
  model.to(device);

  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t train_dataset_size = train_dataset.size().value();
  auto train_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(train_dataset), kTrainBatchSize);

  auto test_dataset =
      torch::data::datasets::MNIST(
          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t test_dataset_size = test_dataset.size().value();
  auto test_loader =
      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

  torch::TensorOptions FloatCUDA =
      torch::TensorOptions(device).dtype(torch::kFloat);
  torch::TensorOptions LongCUDA =
      torch::TensorOptions(device).dtype(torch::kLong);

  // Static tensors: CUDA graph replay reads and writes these fixed buffers,
  // so they are allocated once and reused for every batch.
  torch::Tensor train_data =
      torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor train_targets = torch::zeros({kTrainBatchSize}, LongCUDA);
  torch::Tensor train_output = torch::zeros({1}, FloatCUDA);
  torch::Tensor train_loss = torch::zeros({1}, FloatCUDA);

  torch::Tensor test_data =
      torch::zeros({kTestBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor test_targets = torch::zeros({kTestBatchSize}, LongCUDA);
  torch::Tensor test_output = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_loss = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_total_loss = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_total_correct = torch::zeros({1}, LongCUDA);

  at::cuda::CUDAGraph train_graph;
  at::cuda::CUDAGraph test_graph;

  // Graphs are captured unconditionally; the --use-*-graph flags only decide
  // whether train()/test() replay them or run the eager path.
  capture_train_graph(
      model,
      optimizer,
      train_data,
      train_targets,
      train_output,
      train_loss,
      train_graph);

  capture_test_graph(
      model,
      test_data,
      test_targets,
      test_output,
      test_loss,
      test_total_loss,
      test_total_correct,
      test_graph);

  for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    train(
        epoch,
        model,
        device,
        *train_loader,
        optimizer,
        train_dataset_size,
        train_data,
        train_targets,
        train_output,
        train_loss,
        train_graph,
        use_train_graph);
    test(
        model,
        device,
        *test_loader,
        test_dataset_size,
        test_data,
        test_targets,
        test_output,
        test_loss,
        test_total_loss,
        test_total_correct,
        test_graph,
        use_test_graph);
  }

  std::cout << "Training/testing complete" << std::endl;
  return 0;
}