GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/cpp_cuda_graphs/mnist.cpp
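
Complete source of the LibTorch MNIST example with CUDA Graphs from the PyTorch tutorials. The program trains the classic MNIST convnet and can optionally capture the training step and/or the test step into a CUDA graph and replay it on every iteration; both behaviors are opt-in via the --use-train-graph and --use-test-graph flags parsed in main() below. Assuming the tutorial's CMake build produces a binary named mnist (the build files are not shown here), a typical run would be:

    ./mnist --use-train-graph --use-test-graph
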
#include &lt;ATen/cuda/CUDAEvent.h&gt;
#include &lt;ATen/cuda/CUDAGraph.h&gt;
#include &lt;c10/cuda/CUDAStream.h&gt;
#include &lt;torch/torch.h&gt;

#include &lt;cstddef&gt;
#include &lt;cstdio&gt;
#include &lt;iostream&gt;
#include &lt;string&gt;
#include &lt;vector&gt;

// Where to find the MNIST dataset.
const char* kDataRoot = "./data";

// The batch size for training.
const int64_t kTrainBatchSize = 64;

// The batch size for testing.
const int64_t kTestBatchSize = 1000;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 10;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

// Model that we will be training.
struct Net : torch::nn::Module {
  Net()
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

void stream_sync(
    at::cuda::CUDAStream& dependency,
    at::cuda::CUDAStream& dependent) {
  at::cuda::CUDAEvent cuda_ev;
  cuda_ev.record(dependency);
  cuda_ev.block(dependent);
}
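
// stream_sync is a thin wrapper over CUDA events: recording an event on
// `dependency` and blocking `dependent` on it is the usual cudaEventRecord /
// cudaStreamWaitEvent pairing, so `dependent` cannot run ahead of work already
// enqueued on `dependency`. The wait is enqueued on the stream; the host
// thread is not blocked.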

void training_step(
    Net& model,
    torch::optim::Optimizer& optimizer,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss) {
  optimizer.zero_grad();
  output = model.forward(data);
  loss = torch::nll_loss(output, targets);
  loss.backward();
  optimizer.step();
}
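
// training_step is the unit of work that capture_train_graph records below.
// It reads and writes only the tensors passed in by reference; once the step
// is captured into a CUDA graph, replay reuses those exact allocations, so
// the caller must keep them alive and at fixed device addresses.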

void capture_train_graph(
    Net& model,
    torch::optim::Optimizer& optimizer,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    at::cuda::CUDAGraph& graph,
    const short num_warmup_iters = 7) {
  model.train();

  auto warmupStream = at::cuda::getStreamFromPool();
  auto captureStream = at::cuda::getStreamFromPool();
  auto legacyStream = at::cuda::getCurrentCUDAStream();

  at::cuda::setCurrentCUDAStream(warmupStream);

  stream_sync(legacyStream, warmupStream);

  for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) {
    training_step(model, optimizer, data, targets, output, loss);
  }

  stream_sync(warmupStream, captureStream);
  at::cuda::setCurrentCUDAStream(captureStream);

  graph.capture_begin();
  training_step(model, optimizer, data, targets, output, loss);
  graph.capture_end();

  stream_sync(captureStream, legacyStream);
  at::cuda::setCurrentCUDAStream(legacyStream);
}
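
// For reference, the sequence above is the generic capture recipe for
// at::cuda::CUDAGraph; a minimal sketch, with `payload` standing in for any
// fixed sequence of CUDA ops on preallocated tensors:
//
//   at::cuda::CUDAGraph graph;
//   auto stream = at::cuda::getStreamFromPool();  // capture cannot run on the
//   at::cuda::setCurrentCUDAStream(stream);       // default (legacy) stream
//   payload();               // warm-up: lazy init (cuDNN plans, allocator
//                            // blocks, optimizer state) happens before capture
//   graph.capture_begin();
//   payload();               // recorded into the graph, not executed
//   graph.capture_end();
//   graph.replay();          // later: relaunch the whole recorded sequence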

template &lt;typename DataLoader&gt;
void train(
    size_t epoch,
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t dataset_size,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    at::cuda::CUDAGraph& graph,
    bool use_graph) {
  model.train();

  size_t batch_idx = 0;

  for (const auto& batch : data_loader) {
    if (batch.data.size(0) != kTrainBatchSize ||
        batch.target.size(0) != kTrainBatchSize) {
      continue;
    }

    data.copy_(batch.data);
    targets.copy_(batch.target);

    if (use_graph) {
      graph.replay();
    } else {
      training_step(model, optimizer, data, targets, output, loss);
    }

    if (batch_idx++ % kLogInterval == 0) {
      float train_loss = loss.item&lt;float&gt;();
      std::cout &lt;&lt; "\rTrain Epoch: " &lt;&lt; epoch &lt;&lt; " ["
                &lt;&lt; batch_idx * batch.data.size(0) &lt;&lt; "/" &lt;&lt; dataset_size
                &lt;&lt; "] Loss: " &lt;&lt; train_loss;
    }
  }
}
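
// train() skips any batch whose size differs from kTrainBatchSize: a captured
// graph is shape-specialized, so only full batches can be copied into the
// static tensors and replayed. test() below does the same with kTestBatchSize.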

void test_step(
    Net& model,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss) {
  output = model.forward(data);
  loss = torch::nll_loss(output, targets, {}, torch::Reduction::Sum);
}
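
// Reduction::Sum (rather than the default mean) keeps the per-batch losses
// additive, so test() can accumulate them and divide by dataset_size once.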

void capture_test_graph(
    Net& model,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    torch::Tensor& total_loss,
    torch::Tensor& total_correct,
    at::cuda::CUDAGraph& graph,
    const int num_warmup_iters = 7) {
  torch::NoGradGuard no_grad;
  model.eval();

  auto warmupStream = at::cuda::getStreamFromPool();
  auto captureStream = at::cuda::getStreamFromPool();
  auto legacyStream = at::cuda::getCurrentCUDAStream();

  at::cuda::setCurrentCUDAStream(warmupStream);
  stream_sync(legacyStream, warmupStream);

  for (C10_UNUSED const auto iter : c10::irange(num_warmup_iters)) {
    test_step(model, data, targets, output, loss);
    total_loss += loss;
    total_correct += output.argmax(1).eq(targets).sum();
  }

  stream_sync(warmupStream, captureStream);
  at::cuda::setCurrentCUDAStream(captureStream);

  graph.capture_begin();
  test_step(model, data, targets, output, loss);
  graph.capture_end();

  stream_sync(captureStream, legacyStream);
  at::cuda::setCurrentCUDAStream(legacyStream);
}
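
// Note that only test_step itself is captured: the metric accumulation
// (total_loss, total_correct) stays outside the graph and runs eagerly after
// each replay in test() below.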

template &lt;typename DataLoader&gt;
void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size,
    torch::Tensor& data,
    torch::Tensor& targets,
    torch::Tensor& output,
    torch::Tensor& loss,
    torch::Tensor& total_loss,
    torch::Tensor& total_correct,
    at::cuda::CUDAGraph& graph,
    bool use_graph) {
  torch::NoGradGuard no_grad;

  model.eval();
  loss.zero_();
  total_loss.zero_();
  total_correct.zero_();

  for (const auto& batch : data_loader) {
    if (batch.data.size(0) != kTestBatchSize ||
        batch.target.size(0) != kTestBatchSize) {
      continue;
    }
    data.copy_(batch.data);
    targets.copy_(batch.target);

    if (use_graph) {
      graph.replay();
    } else {
      test_step(model, data, targets, output, loss);
    }
    total_loss += loss;
    total_correct += output.argmax(1).eq(targets).sum();
  }

  float test_loss = total_loss.item&lt;float&gt;() / dataset_size;
  float test_accuracy =
      static_cast&lt;float&gt;(total_correct.item&lt;int64_t&gt;()) / dataset_size;

  std::cout &lt;&lt; std::endl
            &lt;&lt; "Test set: Average loss: " &lt;&lt; test_loss
            &lt;&lt; " | Accuracy: " &lt;&lt; test_accuracy &lt;&lt; std::endl;
}

int main(int argc, char* argv[]) {
  if (!torch::cuda::is_available()) {
    std::cout &lt;&lt; "CUDA is not available!" &lt;&lt; std::endl;
    return -1;
  }

  bool use_train_graph = false;
  bool use_test_graph = false;

  std::vector&lt;std::string&gt; arguments(argv + 1, argv + argc);
  for (std::string& arg : arguments) {
    if (arg == "--use-train-graph") {
      std::cout &lt;&lt; "Using CUDA Graph for training." &lt;&lt; std::endl;
      use_train_graph = true;
    }
    if (arg == "--use-test-graph") {
      std::cout &lt;&lt; "Using CUDA Graph for testing." &lt;&lt; std::endl;
      use_test_graph = true;
    }
  }

  torch::manual_seed(1);
  torch::cuda::manual_seed(1);
  torch::Device device(torch::kCUDA);

  Net model;
  model.to(device);

  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize&lt;&gt;(0.1307, 0.3081))
          .map(torch::data::transforms::Stack&lt;&gt;());
  const size_t train_dataset_size = train_dataset.size().value();
  auto train_loader =
      torch::data::make_data_loader&lt;torch::data::samplers::SequentialSampler&gt;(
          std::move(train_dataset), kTrainBatchSize);

  auto test_dataset =
      torch::data::datasets::MNIST(
          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize&lt;&gt;(0.1307, 0.3081))
          .map(torch::data::transforms::Stack&lt;&gt;());
  const size_t test_dataset_size = test_dataset.size().value();
  auto test_loader =
      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

  torch::TensorOptions FloatCUDA =
      torch::TensorOptions(device).dtype(torch::kFloat);
  torch::TensorOptions LongCUDA =
      torch::TensorOptions(device).dtype(torch::kLong);

  torch::Tensor train_data =
      torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor train_targets = torch::zeros({kTrainBatchSize}, LongCUDA);
  torch::Tensor train_output = torch::zeros({1}, FloatCUDA);
  torch::Tensor train_loss = torch::zeros({1}, FloatCUDA);

  torch::Tensor test_data =
      torch::zeros({kTestBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor test_targets = torch::zeros({kTestBatchSize}, LongCUDA);
  torch::Tensor test_output = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_loss = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_total_loss = torch::zeros({1}, FloatCUDA);
  torch::Tensor test_total_correct = torch::zeros({1}, LongCUDA);
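
  // A captured CUDA graph replays with fixed device addresses, so every
  // tensor the graphs read or write is allocated once above and reused; the
  // train/test loops copy_() each incoming batch into these static buffers
  // rather than allocating new ones.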

  at::cuda::CUDAGraph train_graph;
  at::cuda::CUDAGraph test_graph;

  capture_train_graph(
      model,
      optimizer,
      train_data,
      train_targets,
      train_output,
      train_loss,
      train_graph);

  capture_test_graph(
      model,
      test_data,
      test_targets,
      test_output,
      test_loss,
      test_total_loss,
      test_total_correct,
      test_graph);

  for (size_t epoch = 1; epoch &lt;= kNumberOfEpochs; ++epoch) {
    train(
        epoch,
        model,
        device,
        *train_loader,
        optimizer,
        train_dataset_size,
        train_data,
        train_targets,
        train_output,
        train_loss,
        train_graph,
        use_train_graph);
    test(
        model,
        device,
        *test_loader,
        test_dataset_size,
        test_data,
        test_targets,
        test_output,
        test_loss,
        test_total_loss,
        test_total_correct,
        test_graph,
        use_test_graph);
  }

  std::cout &lt;&lt; " Training/testing complete" &lt;&lt; std::endl;
  return 0;
}