GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/dispatcher/op.cpp
#include <torch/torch.h>
#include <torch/script.h>

#include <ATen/NamedTensorUtils.h>

using torch::Tensor;
using torch::DeviceType;
using torch::autograd::tensor_list;
using torch::autograd::AutogradContext;

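// myadd() below is the public entry point. It does not compute anything
// itself: it looks up the "myops::myadd" schema in the dispatcher and
// re-dispatches, so the kernel that actually runs (CPU, CUDA, Autograd, ...)
// is selected from the dispatch keys of the input tensors.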
// BEGIN myadd
Tensor myadd(const Tensor& self, const Tensor& other) {
  static auto op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("myops::myadd", "")
    .typed<decltype(myadd)>();
  return op.call(self, other);
}
// END myadd

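// TORCH_LIBRARY defines the operator namespace "myops" and registers the
// schema for myadd; the kernels themselves are registered separately with
// TORCH_LIBRARY_IMPL below. Once the compiled library is loaded (for example
// via torch.ops.load_library from Python), the operator is reachable as
// torch.ops.myops.myadd.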
// BEGIN TORCH_LIBRARY
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
// END TORCH_LIBRARY

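// Plain CPU kernel. It only handles float tensors of identical shape:
// data_ptr<float>() raises an error for any other dtype, and the inputs are
// made contiguous first so the flat pointer loop below is valid.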
// BEGIN myadd_cpu
Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  Tensor self = self_.contiguous();
  Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}
// END myadd_cpu

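// Registering myadd_cpu under the CPU dispatch key makes the dispatcher
// route myops::myadd to this kernel whenever the inputs are CPU tensors.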
// BEGIN TORCH_LIBRARY_IMPL CPU
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
// END TORCH_LIBRARY_IMPL CPU

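// The CUDA path mirrors the CPU one: a kernel registered under the CUDA
// dispatch key runs whenever the inputs live on a CUDA device. Here it is
// only a stub that raises an error; a real implementation would launch a
// CUDA kernel performing the same element-wise addition.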
Tensor myadd_cuda(const Tensor& self, const Tensor& other) {
  // Insert your CUDA implementation here
  TORCH_CHECK(0, "CUDA not yet implemented");
}

// BEGIN TORCH_LIBRARY_IMPL CUDA
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}
// END TORCH_LIBRARY_IMPL CUDA

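// Autograd support. MyAddFunction records the operation for the backward
// pass; its forward() re-enters myadd() with autograd handling disabled so
// the call falls through to the backend kernel instead of recursing, and
// backward() returns grad_output for both inputs, since the derivative of
// self + other with respect to either input is 1.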
// BEGIN myadd_autograd
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    // Exclude autograd from the redispatch below. Note: recent PyTorch
    // releases deprecate this guard in favour of
    // at::AutoDispatchBelowADInplaceOrView.
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  // apply() returns the same type as forward(), i.e. a single Tensor here,
  // so it is returned directly rather than indexed.
  return MyAddFunction::apply(self, other);
}
// END myadd_autograd

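// Registering the wrapper under the Autograd key (an alias that covers
// AutogradCPU, AutogradCUDA, and the other backend-specific autograd keys)
// makes myops::myadd differentiable regardless of which backend kernel
// eventually runs.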
// BEGIN TORCH_LIBRARY_IMPL Autograd
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}
// END TORCH_LIBRARY_IMPL Autograd

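// The block below is compiled out with #if 0. It sketches a handler for the
// Named dispatch key: tensor names are unified across the inputs, the math
// runs with names temporarily stripped (NoNamesGuard), and the unified names
// are then propagated to the result.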
#if 0
// BEGIN TORCH_LIBRARY_IMPL Named
Tensor myadd_named(const Tensor& self, const Tensor& other) {
  // TODO: shouldn't need to do size check here
  TORCH_CHECK(self.sizes() == other.sizes());
  auto maybe_outnames = at::unify_from_right(self.names(), other.names());
  auto result = ([&]() {
    at::NoNamesGuard guard;
    return myadd(self, other);
  })();
  at::namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

TORCH_LIBRARY_IMPL(myops, Named, m) {
  m.impl("myadd", myadd_named);
}
// END TORCH_LIBRARY_IMPL Named
#endif
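
// A minimal usage sketch (not part of the original tutorial sources): it
// assumes this translation unit is linked into the calling program and shows
// the Autograd and CPU kernels cooperating through the dispatcher. The
// function name check_myadd is illustrative only.
void check_myadd() {
  torch::Tensor a = torch::ones({3}, torch::requires_grad());
  torch::Tensor b = torch::ones({3});
  // Dispatches to myadd_autograd (Autograd key), which redispatches to
  // myadd_cpu for the actual arithmetic.
  torch::Tensor c = myadd(a, b);
  TORCH_CHECK(c.equal(torch::full({3}, 2.0f)));
  // Backward: d(a + b)/da == 1, so a's gradient is all ones.
  c.sum().backward();
  TORCH_CHECK(a.grad().equal(torch::ones({3})));
}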