Path: advanced_source/dispatcher/op.cpp
#include <torch/torch.h>
#include <torch/script.h>

#include <ATen/NamedTensorUtils.h>

using torch::Tensor;
using torch::DeviceType;
using torch::autograd::tensor_list;
using torch::autograd::AutogradContext;

// BEGIN myadd
// Client entry point: look up the registered "myops::myadd" schema once and
// route every call through the dispatcher, which picks the kernel that
// matches the inputs (CPU, CUDA, Autograd, ...).
Tensor myadd(const Tensor& self, const Tensor& other) {
  static auto op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("myops::myadd", "")
    .typed<decltype(myadd)>();
  return op.call(self, other);
}
// END myadd

// BEGIN TORCH_LIBRARY
// Declare the operator schema in the myops namespace.
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
// END TORCH_LIBRARY

// BEGIN myadd_cpu
// CPU kernel: element-wise addition over contiguous float tensors.
Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  Tensor self = self_.contiguous();
  Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}
// END myadd_cpu

// BEGIN TORCH_LIBRARY_IMPL CPU
// Register the CPU kernel for myops::myadd.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
// END TORCH_LIBRARY_IMPL CPU

Tensor myadd_cuda(const Tensor& self, const Tensor& other) {
  // Insert your CUDA implementation here
  TORCH_CHECK(0, "CUDA not yet implemented");
}

// BEGIN TORCH_LIBRARY_IMPL CUDA
// Register the (stub) CUDA kernel for myops::myadd.
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}
// END TORCH_LIBRARY_IMPL CUDA

// BEGIN myadd_autograd
// Autograd kernel: records the backward function, then redispatches to the
// backend kernel (CPU/CUDA) for the actual forward computation.
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    // Exclude autograd handling while redispatching so the call below reaches
    // the backend kernel directly instead of recursing into this function.
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    // d(self + other)/dself = d(self + other)/dother = 1, so the incoming
    // gradient passes through unchanged to both inputs.
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other)[0];
}
// END myadd_autograd

// BEGIN TORCH_LIBRARY_IMPL Autograd
// Register the autograd kernel for myops::myadd.
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}
// END TORCH_LIBRARY_IMPL Autograd

#if 0
// BEGIN TORCH_LIBRARY_IMPL Named
// Named-tensor support (disabled): check that names unify, run the op with
// names suppressed, then propagate the unified names to the result.
Tensor myadd_named(const Tensor& self, const Tensor& other) {
  // TODO: shouldn't need to do size check here
  TORCH_CHECK(self.sizes() == other.sizes());
  auto maybe_outnames = at::unify_from_right(self.names(), other.names());
  auto result = ([&]() {
    at::NoNamesGuard guard;
    return myadd(self, other);
  })();
  at::namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

TORCH_LIBRARY_IMPL(myops, Named, m) {
  m.impl("myadd", myadd_named);
}
// END TORCH_LIBRARY_IMPL Named
#endif
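
// A minimal usage sketch: the main() below is hypothetical and kept disabled,
// assuming this translation unit is compiled and linked into a program. It
// only illustrates the dispatch path: calling myadd() first hits the Autograd
// kernel, which redispatches to the CPU kernel.
#if 0
#include <iostream>

int main() {
  torch::Tensor a = torch::ones({3}, torch::requires_grad());
  torch::Tensor b = torch::ones({3});
  torch::Tensor c = myadd(a, b);  // myadd_autograd -> myadd_cpu
  c.sum().backward();             // gradient flows through MyAddFunction::backward
  std::cout << c << "\n" << a.grad() << std::endl;
  return 0;
}
#endif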