/*
 * For more details on the Torch Tensor C++ API, we refer to the Torch C++ documentation
 * (https://pytorch.org/cppdocs) and more specifically the C++ API documentation
 * (https://pytorch.org/cppdocs/api/library_root.html) pages on the PyTorch website.
 */
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <utility>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

#include "ctorch.h"

// =============================================================================
// --- File-local helpers
// =============================================================================

namespace {

// Runs `action` and returns whatever it returns. If `action` throws, the
// message is written to stderr with an "[ERROR]: " prefix and the process
// exits with EXIT_FAILURE, so no C++ exception propagates out of the functions
// in this file that use it.
template <typename Action>
auto run_or_exit(Action &&action) -> decltype(action()) {
  try {
    return action();
  } catch (const torch::Error &e) {
    // torch::Error must be caught before std::exception (its base) so that the
    // libtorch-specific message is used.
    std::cerr << "[ERROR]: " << e.msg() << std::endl;
  } catch (const std::exception &e) {
    std::cerr << "[ERROR]: " << e.what() << std::endl;
  }
  std::exit(EXIT_FAILURE);
}

} // namespace

// =============================================================================
// --- Constant expressions
// =============================================================================

// Mapping from FTorch torch_data_t to libtorch Dtype.
// uint8 and float16 are rejected (fatal); an unrecognised value falls back to
// torch::kFloat32 with a warning.
constexpr auto get_libtorch_dtype(torch_data_t dtype) {
  switch (dtype) {
  case torch_kUInt8:
    std::cerr << "[ERROR]: uint8 not supported in Fortran" << std::endl;
    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
    std::exit(EXIT_FAILURE);
  case torch_kInt8:
    return torch::kInt8;
  case torch_kInt16:
    return torch::kInt16;
  case torch_kInt32:
    return torch::kInt32;
  case torch_kInt64:
    return torch::kInt64;
  case torch_kFloat16:
    std::cerr << "[ERROR]: float16 not supported in Fortran" << std::endl;
    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
    std::exit(EXIT_FAILURE);
  case torch_kFloat32:
    return torch::kFloat32;
  case torch_kFloat64:
    return torch::kFloat64;
  default:
    std::cerr << "[WARNING]: unknown data type, setting to torch_kFloat32" << std::endl;
    return torch::kFloat32;
  }
}

// Mapping from libtorch Dtype to FTorch torch_data_t.
// Any dtype without an FTorch equivalent is a fatal error.
torch_data_t get_ftorch_dtype(caffe2::TypeMeta dtype) {
  if (dtype == torch::kUInt8) {
    std::cerr << "[ERROR]: uint8 not supported in Fortran" << std::endl;
    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
    std::exit(EXIT_FAILURE);
  } else if (dtype == torch::kInt8) {
    return torch_kInt8;
  } else if (dtype == torch::kInt16) {
    return torch_kInt16;
  } else if (dtype == torch::kInt32) {
    return torch_kInt32;
  } else if (dtype == torch::kInt64) {
    return torch_kInt64;
  } else if (dtype == torch::kFloat16) {
    std::cerr << "[ERROR]: float16 not supported in Fortran" << std::endl;
    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
    std::exit(EXIT_FAILURE);
  } else if (dtype == torch::kFloat32) {
    return torch_kFloat32;
  } else if (dtype == torch::kFloat64) {
    return torch_kFloat64;
  } else {
    std::cerr << "[ERROR]: data type " << dtype << " not supported in Fortran" << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

// Mapping from FTorch torch_device_t to libtorch Device.
//   device_type  : FTorch device enumerator
//   device_index : device ordinal; -1 means "unset"
// For CUDA an unset index defaults to 0 (with a warning) and an index outside
// [0, device_count) is fatal. An unknown device type falls back to CPU with a
// warning.
torch::Device get_libtorch_device(torch_device_t device_type, int device_index) {
  switch (device_type) {
  case torch_kCPU:
    if (device_index != -1) {
      std::cerr << "[WARNING]: device index unused for CPU-only runs" << std::endl;
    }
    return torch::Device(torch::kCPU);
  case torch_kCUDA:
    if (device_index == -1) {
      std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl;
      device_index = 0;
    }
    if (device_index >= 0 && device_index < torch::cuda::device_count()) {
      return torch::Device(torch::kCUDA, device_index);
    } else {
      std::cerr << "[ERROR]: invalid device index " << device_index << " for device count "
                << torch::cuda::device_count() << std::endl;
      std::exit(EXIT_FAILURE);
    }
  default:
    std::cerr << "[WARNING]: unknown device type, setting to torch_kCPU" << std::endl;
    return torch::Device(torch::kCPU);
  }
}

// Mapping from libtorch DeviceType to FTorch torch_device_t.
// Device types other than CPU and CUDA are a fatal error.
torch_device_t get_ftorch_device(torch::DeviceType device_type) {
  switch (device_type) {
  case torch::kCPU:
    return torch_kCPU;
  case torch::kCUDA:
    return torch_kCUDA;
  default:
    std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch"
              << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

// =============================================================================
// --- Functions for constructing tensors
// =============================================================================

// Allocates a new tensor with torch::empty and moves it to the requested device.
//   ndim, shape   : number of dimensions and pointer to `ndim` extents
//   dtype         : FTorch data type of the elements
//   device_type   : FTorch device type; device_index: ordinal (-1 = unset)
//   requires_grad : value given to the torch::AutoGradMode guard that is active
//                   for the duration of this call
// Returns a heap-allocated torch::Tensor handle; release it with
// torch_tensor_delete. Any failure is fatal.
torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
                           torch_device_t device_type, int device_index = -1,
                           const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  return run_or_exit([&]() -> torch_tensor_t {
    // This doesn't throw if shape and dimensions are incompatible
    c10::IntArrayRef vshape(shape, ndim);
    return new torch::Tensor(torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype)))
                                 .to(get_libtorch_device(device_type, device_index)));
  });
}

// Same as torch_empty, but the tensor is created with torch::zeros.
torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
                           torch_device_t device_type, int device_index = -1,
                           const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  return run_or_exit([&]() -> torch_tensor_t {
    // This doesn't throw if shape and dimensions are incompatible
    c10::IntArrayRef vshape(shape, ndim);
    return new torch::Tensor(torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype)))
                                 .to(get_libtorch_device(device_type, device_index)));
  });
}

// Same as torch_empty, but the tensor is created with torch::ones.
torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
                          torch_device_t device_type, int device_index = -1,
                          const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  return run_or_exit([&]() -> torch_tensor_t {
    // This doesn't throw if shape and dimensions are incompatible
    c10::IntArrayRef vshape(shape, ndim);
    return new torch::Tensor(torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype)))
                                 .to(get_libtorch_device(device_type, device_index)));
  });
}

// Exposes the given data as a Tensor without taking ownership of the original
// data.
//   data          : pointer to the first element of the caller's buffer
//   ndim          : number of dimensions
//   shape/strides : pointers to `ndim` extents / strides
// The remaining parameters and the return value are as for torch_empty.
torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
                               const int64_t *strides, torch_data_t dtype,
                               torch_device_t device_type, int device_index = -1,
                               const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  return run_or_exit([&]() -> torch_tensor_t {
    // This doesn't throw if shape and dimensions are incompatible
    c10::IntArrayRef vshape(shape, ndim);
    c10::IntArrayRef vstrides(strides, ndim);
    return new torch::Tensor(
        torch::from_blob(data, vshape, vstrides, torch::dtype(get_libtorch_dtype(dtype)))
            .to(get_libtorch_device(device_type, device_index)));
  });
}

// =====================================================================================
// --- Functions for interrogating tensors
// =====================================================================================

// Returns the tensor's data pointer as void*, obtained through the typed
// data_ptr<T>() accessor selected by `dtype`. uint8, float16 and unknown data
// types are fatal, as is any exception raised by the accessor.
void *torch_to_blob(const torch_tensor_t tensor, const torch_data_t dtype) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return run_or_exit([&]() -> void * {
    switch (dtype) {
    case torch_kUInt8:
      std::cerr << "[ERROR]: uint8 not supported" << std::endl;
      std::exit(EXIT_FAILURE);
    case torch_kInt8:
      return static_cast<void *>(t->data_ptr<int8_t>());
    case torch_kInt16:
      return static_cast<void *>(t->data_ptr<int16_t>());
    case torch_kInt32:
      return static_cast<void *>(t->data_ptr<int32_t>());
    case torch_kInt64:
      return static_cast<void *>(t->data_ptr<int64_t>());
    case torch_kFloat16:
      std::cerr << "[ERROR]: float16 not supported" << std::endl;
      // NOTE: std::float16_t is available but only with C++23
      std::exit(EXIT_FAILURE);
    case torch_kFloat32:
      // NOTE: std::float32_t is available but only with C++23
      return static_cast<void *>(t->data_ptr<float>());
    case torch_kFloat64:
      // NOTE: std::float64_t is available but only with C++23
      return static_cast<void *>(t->data_ptr<double>());
    default:
      std::cerr << "[ERROR]: unknown data type" << std::endl;
      std::exit(EXIT_FAILURE);
    }
  });
}

// Streams the tensor to stdout followed by a newline.
void torch_tensor_print(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  std::cout << *t << std::endl;
}

// Returns the number of dimensions of the tensor.
int torch_tensor_get_rank(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return static_cast<int>(t->sizes().size());
}

// Returns a pointer to the tensor's extents (one entry per dimension).
// NOTE(review): the pointer comes from t->sizes().data(), so it presumably
// refers to storage owned by the tensor and should not outlive it - confirm.
#ifdef UNIX
const long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().data();
}
#else
const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().data();
}
#endif

// Returns the FTorch data type of the tensor (fatal if it has no FTorch
// equivalent, see get_ftorch_dtype).
torch_data_t torch_tensor_get_dtype(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return get_ftorch_dtype(t->dtype());
}

// Returns the FTorch device type the tensor lives on (fatal if it is neither
// CPU nor CUDA, see get_ftorch_device).
torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return get_ftorch_device(t->device().type());
}

// Returns the index of the device the tensor lives on, as reported by libtorch.
int torch_tensor_get_device_index(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->device().index();
}

// =====================================================================================
// --- Functions for deallocating tensors
// =====================================================================================

// Destroys a tensor handle previously returned by one of the functions in this
// file that allocate a torch::Tensor with `new`.
void torch_tensor_delete(torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  delete t;
}

// =====================================================================================
// --- Operator overloads acting on tensors
// =====================================================================================
//
// Each function below returns a newly allocated tensor handle holding the
// result; release it with torch_tensor_delete. Exceptions raised by libtorch
// (for example on incompatible operands) are reported and are fatal.

// Returns a detached clone of `input`. The torch::AutoGradMode guard active
// during the call is set from input's requires_grad flag.
torch_tensor_t torch_tensor_assign(const torch_tensor_t input) {
  auto source = reinterpret_cast<torch::Tensor *>(input);
  torch::AutoGradMode enable_grad(source->requires_grad());
  return run_or_exit(
      [&]() -> torch_tensor_t { return new torch::Tensor(source->detach().clone()); });
}

// Returns tensor1 + tensor2.
torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1, const torch_tensor_t tensor2) {
  auto lhs = reinterpret_cast<torch::Tensor *>(tensor1);
  auto rhs = reinterpret_cast<torch::Tensor *>(tensor2);
  return run_or_exit([&]() -> torch_tensor_t { return new torch::Tensor(*lhs + *rhs); });
}

// Returns -tensor.
torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor) {
  auto operand = reinterpret_cast<torch::Tensor *>(tensor);
  return run_or_exit([&]() -> torch_tensor_t { return new torch::Tensor(-(*operand)); });
}

// Returns tensor1 - tensor2.
torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
                                     const torch_tensor_t tensor2) {
  auto lhs = reinterpret_cast<torch::Tensor *>(tensor1);
  auto rhs = reinterpret_cast<torch::Tensor *>(tensor2);
  return run_or_exit([&]() -> torch_tensor_t { return new torch::Tensor(*lhs - *rhs); });
}

// Returns tensor1 * tensor2.
torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
                                     const torch_tensor_t tensor2) {
  auto lhs = reinterpret_cast<torch::Tensor *>(tensor1);
  auto rhs = reinterpret_cast<torch::Tensor *>(tensor2);
  return run_or_exit([&]() -> torch_tensor_t { return new torch::Tensor(*lhs * *rhs); });
}

// Returns tensor1 / tensor2.
torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1, const torch_tensor_t tensor2) {
  auto lhs = reinterpret_cast<torch::Tensor *>(tensor1);
  auto rhs = reinterpret_cast<torch::Tensor *>(tensor2);
  return run_or_exit([&]() -> torch_tensor_t { return new torch::Tensor(*lhs / *rhs); });
}

// Returns tensor raised to an integer power. `exponent` is read as a pointer
// to int.
torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor, const torch_int_t exponent) {
  auto base = reinterpret_cast<torch::Tensor *>(tensor);
  // NOTE: The following cast will only work for integer exponents
  auto power = reinterpret_cast<int *>(exponent);
  return run_or_exit(
      [&]() -> torch_tensor_t { return new torch::Tensor(torch::pow(*base, *power)); });
}

// Returns tensor raised to a floating point power. `exponent` is read as a
// pointer to float.
torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
                                        const torch_float_t exponent) {
  auto base = reinterpret_cast<torch::Tensor *>(tensor);
  // NOTE: The following cast will only work for floating point exponents
  auto power = reinterpret_cast<float *>(exponent);
  return run_or_exit(
      [&]() -> torch_tensor_t { return new torch::Tensor(torch::pow(*base, *power)); });
}

// =============================================================================
// --- Torch model API
// =============================================================================

// Puts the module into training mode (train()) when `is_training` is true and
// into evaluation mode (eval()) otherwise.
void set_is_training(torch_jit_script_module_t module, const bool is_training = false) {
  auto model = static_cast<torch::jit::script::Module *>(module);
  if (is_training) {
    model->train();
  } else {
    model->eval();
  }
}

// Loads a TorchScript module from `filename` onto the requested device and sets
// its training/evaluation mode.
//   device_type/device_index : target device (see get_libtorch_device)
//   requires_grad            : value given to the torch::AutoGradMode guard
//                              active for the duration of this call
//   is_training              : forwarded to set_is_training
// Returns a heap-allocated module handle; release it with
// torch_jit_module_delete. Any failure is fatal.
torch_jit_script_module_t torch_jit_load(const char *filename,
                                         const torch_device_t device_type = torch_kCPU,
                                         const int device_index = -1,
                                         const bool requires_grad = false,
                                         const bool is_training = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  return run_or_exit([&]() -> torch_jit_script_module_t {
    auto *loaded = new torch::jit::script::Module(
        torch::jit::load(filename, get_libtorch_device(device_type, device_index)));
    set_is_training(loaded, is_training);
    return loaded;
  });
}

// Runs the module's forward method on `nin` input tensors and stores the result
// in the `nout` pre-existing output tensors.
//   module        : handle returned by torch_jit_load
//   inputs, nin   : array of `nin` tensor handles
//   outputs, nout : array of `nout` tensor handles that receive the results
//   requires_grad : value given to the torch::AutoGradMode guard active for the
//                   duration of this call
// A Tensor result is written to outputs[0]; a Tuple result has its first `nout`
// elements written to outputs[0..nout-1]. Any other result type, too few model
// outputs, or any exception is fatal.
void torch_jit_module_forward(const torch_jit_script_module_t module,
                              const torch_tensor_t *inputs, const int nin,
                              torch_tensor_t *outputs, const int nout,
                              const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  // Here we cast the pointers we received in to Tensor objects
  auto model = static_cast<torch::jit::script::Module *>(module);
  auto in = reinterpret_cast<torch::Tensor *const *>(inputs);
  auto out = reinterpret_cast<torch::Tensor **>(outputs);

  run_or_exit([&]() {
    // Generate a vector of IValues (placeholders for various Torch types) and
    // populate it with the Tensors pointed at by the input handles.
    std::vector<torch::jit::IValue> inputs_vec;
    if (nin > 0) {
      inputs_vec.reserve(static_cast<std::size_t>(nin));
    }
    for (int i = 0; i < nin; ++i) {
      // For each IValue check it is of Tensor type
      torch::jit::IValue candidate(*(in[i]));
      if (!candidate.isTensor()) {
        std::cerr << "[ERROR]: One of the inputs to torch_jit_module_forward is "
                     "not a Tensor."
                  << std::endl;
        std::exit(EXIT_FAILURE);
      }
      inputs_vec.push_back(std::move(candidate));
    }

    auto model_out = model->forward(inputs_vec);
    // NOTE: the assignments below deliberately go through std::move(*out[i]).
    // Assigning to an rvalue Tensor selects libtorch's rvalue-qualified
    // operator=, which copies the values into the existing output tensor
    // instead of rebinding the handle. Do not "simplify" this.
    if (model_out.isTensor()) {
      // Single output models will return a tensor directly.
      if (nout < 1) {
        std::cerr << "[ERROR]: Model returned a Tensor but no output tensor was provided."
                  << std::endl;
        std::exit(EXIT_FAILURE);
      }
      std::move(*out[0]) = model_out.toTensor();
    } else if (model_out.isTuple()) {
      // Multiple output models will return a tuple => cast to tensors.
      const auto tuple = model_out.toTuple();
      const auto &elements = tuple->elements();
      if (static_cast<int64_t>(elements.size()) < nout) {
        std::cerr << "[ERROR]: Model returned " << elements.size() << " outputs but " << nout
                  << " were requested." << std::endl;
        std::exit(EXIT_FAILURE);
      }
      for (int i = 0; i < nout; ++i) {
        std::move(*out[i]) = elements[i].toTensor();
      }
    } else {
      // The outputs cannot be filled, so continuing would silently hand
      // unmodified output tensors back to the caller.
      std::cerr << "[ERROR]: Model Output is neither Tensor nor Tuple." << std::endl;
      std::exit(EXIT_FAILURE);
    }
  });
}

// Destroys a module handle previously returned by torch_jit_load.
void torch_jit_module_delete(torch_jit_script_module_t module) {
  auto m = reinterpret_cast<torch::jit::script::Module *>(module);
  delete m;
}