ctorch.cpp Source File


Source Code

/*
 * For more details on the Torch Tensor C++ API, we refer to the Torch C++ documentation
 * (https://pytorch.org/cppdocs) and more specifically the C++ API documentation
 * (https://pytorch.org/cppdocs/api/library_root.html) pages on the PyTorch website.
 */
#include <torch/script.h>
#include <torch/torch.h>

#include "ctorch.h"

#ifndef GPU_DEVICE
#define GPU_DEVICE GPU_DEVICE_NONE
#endif

// =============================================================================
// --- Functions to aid in consistent error handling
// =============================================================================

// Report a fatal error and terminate the program.
//
//   message : text written to stderr after an "[ERROR]: " prefix.
//   cleanup : optional callable run after the message is printed and before
//             exiting (e.g. to free memory the caller allocated).
//
// This function does not return: it always finishes with exit(EXIT_FAILURE).
void ctorch_error(const std::string &message,
                  const std::function<void()> &cleanup = nullptr) {
  std::cerr << "[ERROR]: " << message << std::endl;
  const bool has_cleanup = static_cast<bool>(cleanup);
  if (has_cleanup) {
    cleanup();
  }
  exit(EXIT_FAILURE);
}

// Report a non-fatal problem on stderr, prefixed with "[WARNING]: ".
// Execution continues normally after the message is written.
void ctorch_warn(const std::string &message) {
  const char *const prefix = "[WARNING]: ";
  std::cerr << prefix << message << std::endl;
}

// =============================================================================
// --- Constant expressions
// =============================================================================

// Map an FTorch torch_data_t value onto the corresponding libtorch ScalarType.
//
// uint8 and float16 have no Fortran counterpart, so requesting either one is
// fatal (ctorch_error terminates the program). Unrecognised values produce a
// warning and fall back to torch::kFloat32.
// See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
constexpr auto get_libtorch_dtype(torch_data_t dtype) {
  switch (dtype) {
  // Types with a direct Fortran equivalent.
  case torch_kInt8:
    return torch::kInt8;
  case torch_kInt16:
    return torch::kInt16;
  case torch_kInt32:
    return torch::kInt32;
  case torch_kInt64:
    return torch::kInt64;
  case torch_kFloat32:
    return torch::kFloat32;
  case torch_kFloat64:
    return torch::kFloat64;
  // Types Fortran cannot represent.
  case torch_kUInt8:
    ctorch_error("uint8 not supported in Fortran");
    return torch::kInt8; // not reached: ctorch_error terminates the program
  case torch_kFloat16:
    ctorch_error("float16 not supported in Fortran");
    return torch::kFloat32; // not reached: ctorch_error terminates the program
  default:
    ctorch_warn("unknown data type, setting to torch_kFloat32");
    return torch::kFloat32;
  }
}

// Map a libtorch dtype (for example the one reported by Tensor::dtype()) back to
// the FTorch torch_data_t enumeration.
//
// uint8, float16 and any other dtype without a Fortran counterpart are fatal.
// See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
torch_data_t get_ftorch_dtype(caffe2::TypeMeta dtype) {
  // Supported types map straight across.
  if (dtype == torch::kInt8) {
    return torch_kInt8;
  }
  if (dtype == torch::kInt16) {
    return torch_kInt16;
  }
  if (dtype == torch::kInt32) {
    return torch_kInt32;
  }
  if (dtype == torch::kInt64) {
    return torch_kInt64;
  }
  if (dtype == torch::kFloat32) {
    return torch_kFloat32;
  }
  if (dtype == torch::kFloat64) {
    return torch_kFloat64;
  }

  // Everything else terminates the program.
  if (dtype == torch::kUInt8) {
    ctorch_error("uint8 not supported in Fortran");
  } else if (dtype == torch::kFloat16) {
    ctorch_error("float16 not supported in Fortran");
  } else {
    std::cerr << "[ERROR]: data type " << dtype << " not supported in Fortran"
              << std::endl;
    exit(EXIT_FAILURE);
  }
  return torch_kFloat32; // Dummy return to satisfy the compiler
}

// Translate an FTorch device type and index into a libtorch Device.
//
// A device_index of -1 means the caller did not ask for a specific device.
//   * torch_kCPU  : any index other than -1 is ignored, with a warning.
//   * torch_kCUDA / torch_kXPU (compiled in only when GPU_DEVICE is
//     GPU_DEVICE_CUDA / GPU_DEVICE_XPU respectively): an unset index becomes 0
//     with a warning; an index outside [0, device count) is fatal.
//   * torch_kMPS  : a single device is used; an index other than -1 or 0 only
//     triggers a warning.
//   * anything else, including a GPU type whose support was not compiled in,
//     falls back to the CPU with a warning.
const auto get_libtorch_device(torch_device_t device_type, int device_index) {
  switch (device_type) {
  case torch_kCPU: {
    const bool index_requested = (device_index != -1);
    if (index_requested) {
      ctorch_warn("device index unused for CPU-only runs");
    }
    return torch::Device(torch::kCPU);
  }
#if GPU_DEVICE == GPU_DEVICE_CUDA
  case torch_kCUDA: {
    if (device_index == -1) {
      ctorch_warn("device index unset, defaulting to 0");
      device_index = 0;
    }
    // Short-circuit: the device count is only queried for non-negative indices.
    const bool out_of_range =
        device_index < 0 || device_index >= torch::cuda::device_count();
    if (out_of_range) {
      std::cerr << "[ERROR]: invalid device index " << device_index
                << " for device count " << torch::cuda::device_count() << std::endl;
      exit(EXIT_FAILURE);
    }
    return torch::Device(torch::kCUDA, device_index);
  }
#endif
  case torch_kMPS: {
    const bool extra_index_given = (device_index != -1 && device_index != 0);
    if (extra_index_given) {
      ctorch_warn("Only one device is available for MPS runs");
    }
    return torch::Device(torch::kMPS);
  }
#if GPU_DEVICE == GPU_DEVICE_XPU
  case torch_kXPU: {
    if (device_index == -1) {
      ctorch_warn("device index unset, defaulting to 0");
      device_index = 0;
    }
    // Short-circuit: the device count is only queried for non-negative indices.
    const bool out_of_range =
        device_index < 0 || device_index >= torch::xpu::device_count();
    if (out_of_range) {
      std::cerr << "[ERROR]: invalid device index " << device_index
                << " for XPU device count " << torch::xpu::device_count() << std::endl;
      exit(EXIT_FAILURE);
    }
    return torch::Device(torch::kXPU, device_index);
  }
#endif
  default:
    ctorch_warn("unknown device type, setting to torch_kCPU");
    return torch::Device(torch::kCPU);
  }
}

// Map a libtorch DeviceType back to the FTorch torch_device_t enumeration.
// Device types FTorch has no enumerator for terminate the program.
const torch_device_t get_ftorch_device(torch::DeviceType device_type) {
  if (device_type == torch::kCPU) {
    return torch_kCPU;
  }
  if (device_type == torch::kCUDA) {
    return torch_kCUDA;
  }
  if (device_type == torch::kXPU) {
    return torch_kXPU;
  }
  if (device_type == torch::kMPS) {
    return torch_kMPS;
  }
  std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch"
            << std::endl;
  exit(EXIT_FAILURE);
}

// =============================================================================
// --- Functions for validating tensors
// =============================================================================

// Throw std::invalid_argument when the tensor pointer is null.
// `name` identifies the tensor in the error message.
void validate_tensor_not_null(const torch::Tensor *t, const std::string &name) {
  if (t != nullptr) {
    return;
  }
  throw std::invalid_argument(name + " is null.");
}

// Throw std::invalid_argument when the tensor is undefined (holds no data).
// Does not check for a null pointer; callers must do that first, as
// validate_tensor does.
void validate_tensor_defined(const torch::Tensor *t, const std::string &name) {
  const bool is_defined = t->defined();
  if (is_defined) {
    return;
  }
  throw std::invalid_argument(name + " is undefined.");
}

// Validate a tensor handle: it must be non-null and refer to a defined tensor.
// Throws std::invalid_argument, naming the tensor via `name`, on failure.
void validate_tensor(const torch::Tensor *t, const std::string &name) {
  // The null check must run first because the definedness check dereferences t.
  validate_tensor_not_null(t, name);
  validate_tensor_defined(t, name);
}

// Throw std::runtime_error when autograd is not tracking the tensor
// (requires_grad is false). `t` is assumed non-null.
void validate_requires_grad(const torch::Tensor *t, const std::string &name) {
  const bool tracked = t->requires_grad();
  if (tracked) {
    return;
  }
  throw std::runtime_error(name + " does not have requires_grad set.");
}

// Throw std::runtime_error when the tensor has no gradient yet (its grad() is
// undefined). `t` is assumed non-null.
void validate_gradient_defined(const torch::Tensor *t, const std::string &name) {
  const bool has_gradient = t->grad().defined();
  if (has_gradient) {
    return;
  }
  throw std::runtime_error(
      name + " has an undefined gradient.\nPerhaps you forgot to call backward.");
}

// =============================================================================
// --- Functions for constructing tensors
// =============================================================================

// Allocate a new tensor whose element values are left uninitialised.
//
//   ndim, shape   : rank of the tensor and a pointer to its ndim extents.
//                   Nothing checks that `shape` really holds ndim entries.
//   dtype         : FTorch data type, mapped via get_libtorch_dtype.
//   device_type,
//   device_index  : FTorch device, mapped via get_libtorch_device
//                   (-1 means no particular index was requested).
//   requires_grad : whether autograd should track the new tensor.
//
// Returns a heap-allocated torch::Tensor handle (presumably released later with
// torch_tensor_delete). Any exception is reported through ctorch_error, which
// frees the handle and terminates the program.
torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
                           torch_device_t device_type, int device_index = -1,
                           const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  torch::Tensor *handle = new torch::Tensor;
  const auto free_handle = [&]() { delete handle; };
  try {
    const c10::IntArrayRef extents(shape, ndim);
    const auto scalar_type = get_libtorch_dtype(dtype);
    const auto device = get_libtorch_device(device_type, device_index);
    const auto options = torch::TensorOptions()
                             .dtype(scalar_type)
                             .device(device)
                             .requires_grad(requires_grad);
    *handle = torch::empty(extents, options);
  } catch (const torch::Error &e) {
    ctorch_error(e.msg(), free_handle);
  } catch (const std::exception &e) {
    ctorch_error(e.what(), free_handle);
  }
  return handle;
}

// Allocate a new tensor with every element set to zero.
//
//   ndim, shape   : rank of the tensor and a pointer to its ndim extents.
//                   Nothing checks that `shape` really holds ndim entries.
//   dtype         : FTorch data type, mapped via get_libtorch_dtype.
//   device_type,
//   device_index  : FTorch device, mapped via get_libtorch_device
//                   (-1 means no particular index was requested).
//   requires_grad : whether autograd should track the new tensor.
//
// Returns a heap-allocated torch::Tensor handle. Any exception is reported
// through ctorch_error, which frees the handle and terminates the program.
torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
                           torch_device_t device_type, int device_index = -1,
                           const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  torch::Tensor *result = new torch::Tensor;
  const auto discard_result = [&]() { delete result; };
  try {
    const c10::IntArrayRef dims(shape, ndim);
    const auto torch_dtype = get_libtorch_dtype(dtype);
    const auto torch_device = get_libtorch_device(device_type, device_index);
    auto options = torch::TensorOptions().dtype(torch_dtype);
    options = options.device(torch_device).requires_grad(requires_grad);
    *result = torch::zeros(dims, options);
  } catch (const torch::Error &err) {
    ctorch_error(err.msg(), discard_result);
  } catch (const std::exception &err) {
    ctorch_error(err.what(), discard_result);
  }
  return result;
}

// Allocate a new tensor with every element set to one.
//
//   ndim, shape   : rank of the tensor and a pointer to its ndim extents.
//                   Nothing checks that `shape` really holds ndim entries.
//   dtype         : FTorch data type, mapped via get_libtorch_dtype.
//   device_type,
//   device_index  : FTorch device, mapped via get_libtorch_device
//                   (-1 means no particular index was requested).
//   requires_grad : whether autograd should track the new tensor.
//
// Returns a heap-allocated torch::Tensor handle. Any exception is reported
// through ctorch_error, which frees the handle and terminates the program.
torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
                          torch_device_t device_type, int device_index = -1,
                          const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  torch::Tensor *result = new torch::Tensor;
  const auto discard_result = [&]() { delete result; };
  try {
    const c10::IntArrayRef dims(shape, ndim);
    const auto torch_dtype = get_libtorch_dtype(dtype);
    const auto torch_device = get_libtorch_device(device_type, device_index);
    auto options = torch::TensorOptions().dtype(torch_dtype);
    options = options.device(torch_device).requires_grad(requires_grad);
    *result = torch::ones(dims, options);
  } catch (const torch::Error &err) {
    ctorch_error(err.msg(), discard_result);
  } catch (const std::exception &err) {
    ctorch_error(err.what(), discard_result);
  }
  return result;
}

// Wrap existing memory as a tensor without copying it or taking ownership.
//
//   data          : caller-owned buffer; it must outlive the returned tensor.
//   ndim, shape,
//   strides       : rank plus ndim extents and ndim strides (in elements)
//                   describing the layout of `data`. Nothing checks that the
//                   arrays really hold ndim entries.
//   dtype         : FTorch data type, mapped via get_libtorch_dtype.
//   device_type,
//   device_index  : FTorch device, mapped via get_libtorch_device.
//   requires_grad : whether autograd should track the tensor.
//
// Returns a heap-allocated torch::Tensor handle viewing `data`. Any exception
// is reported through ctorch_error, which frees the handle and terminates.
torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
                               const int64_t *strides, torch_data_t dtype,
                               torch_device_t device_type, int device_index = -1,
                               const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  torch::Tensor *view = new torch::Tensor;
  const auto release_view = [&]() { delete view; };

  try {
    const c10::IntArrayRef extents(shape, ndim);
    const c10::IntArrayRef steps(strides, ndim);
    const auto scalar_type = get_libtorch_dtype(dtype);
    const auto device = get_libtorch_device(device_type, device_index);
    const auto options = torch::TensorOptions()
                             .dtype(scalar_type)
                             .device(device)
                             .requires_grad(requires_grad);
    *view = torch::from_blob(data, extents, steps, options);
  } catch (const torch::Error &e) {
    ctorch_error(e.msg(), release_view);
  } catch (const std::exception &e) {
    ctorch_error(e.what(), release_view);
  }
  return view;
}

// =====================================================================================
// --- Functions for interrogating tensors
// =====================================================================================

void torch_tensor_print(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  std::cout << *t << std::endl;
}

int torch_tensor_get_rank(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().size();
}

// Return a pointer to the tensor's extents: one entry per dimension, i.e.
// torch_tensor_get_rank(tensor) entries in total.
//
// The pointer comes from t->sizes().data(), so it points into the tensor's own
// shape metadata rather than at a copy. NOTE(review): presumably it is only
// valid while the tensor is alive and its shape is unchanged — confirm.
//
// There are two definitions because sizes().data() yields const int64_t*, and
// the declared return type must name the same integer type as int64_t for that
// pointer to convert implicitly: `long int` when UNIX is defined, `long long
// int` otherwise.
// NOTE(review): on a platform that defines UNIX but spells int64_t as
// `long long` (e.g. macOS), the first branch would not compile — confirm how
// the build sets UNIX.
#ifdef UNIX
const long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().data();
}
#else
const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().data();
}
#endif

// Report the tensor's element type as an FTorch torch_data_t. This is fatal
// (via get_ftorch_dtype) if the dtype has no Fortran counterpart.
torch_data_t torch_tensor_get_dtype(const torch_tensor_t tensor) {
  const auto *t = reinterpret_cast<const torch::Tensor *>(tensor);
  const caffe2::TypeMeta meta = t->dtype();
  return get_ftorch_dtype(meta);
}

// Report the type of device holding the tensor as an FTorch torch_device_t.
// This is fatal (via get_ftorch_device) for device types FTorch does not know.
torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor) {
  const auto *t = reinterpret_cast<const torch::Tensor *>(tensor);
  const torch::Device location = t->device();
  return get_ftorch_device(location.type());
}

int torch_tensor_get_device_index(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->device().index();
}

bool torch_tensor_requires_grad(const torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->requires_grad();
}

// =====================================================================================
// --- Functions for deallocating tensors
// =====================================================================================

// Destroy a tensor handle allocated with `new torch::Tensor` (as the
// constructors such as torch_empty do). Passing a null handle is a no-op.
void torch_tensor_delete(torch_tensor_t tensor) {
  delete reinterpret_cast<torch::Tensor *>(tensor);
}

// =====================================================================================
// --- Functions for manipulating tensors
// =====================================================================================

// Overwrite every element of the tensor with zero, in place.
// validate_tensor throws std::invalid_argument if the handle is null or the
// tensor is undefined.
void torch_tensor_zero(torch_tensor_t tensor) {
  auto *target = reinterpret_cast<torch::Tensor *>(tensor);
  validate_tensor(target, "Input tensor");
  target->zero_();
}

void torch_tensor_to(const torch_tensor_t source_tensor, torch_tensor_t target_tensor,
                     bool non_blocking) {
  auto source_tens = reinterpret_cast<torch::Tensor *>(source_tensor);
  auto target_tens = reinterpret_cast<torch::Tensor *>(target_tensor);
  validate_tensor(source_tens, "Source tensor");
  validate_tensor(target_tens, "Target tensor");

  torch::Device device_type = target_tens->device();
  at::ScalarType dtype = target_tens->scalar_type();

  // For non-blocking usage see:
  // https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html
  std::move(*target_tens) = source_tens->to(device_type, dtype, non_blocking);
}

// =====================================================================================
// --- Operator overloads acting on tensors
// =====================================================================================

// Assign the values of `input` to `output`, then set output's requires_grad
// flag to match input's.
//
// Both tensors must be non-null and defined; otherwise validate_tensor throws
// std::invalid_argument. Gradient mode is enabled during the copy only if input
// requires grad.
void torch_tensor_assign(torch_tensor_t output, const torch_tensor_t input) {
  auto *dst = reinterpret_cast<torch::Tensor *>(output);
  auto *src = reinterpret_cast<torch::Tensor *const>(input);
  validate_tensor(dst, "Output tensor");
  validate_tensor(src, "Input tensor");
  torch::AutoGradMode enable_grad(src->requires_grad());
  // Assigning through an rvalue writes into output's existing data. As a
  // result, an output created with torch_tensor_from_array keeps pointing at
  // its Fortran array. A plain `*dst = *src` would rebind output instead, and
  // the Fortran array would no longer be updated.
  std::move(*dst) = *src;
  // Always overwrite requires_grad, matching PyTorch's behaviour. See the
  // Python examples on https://github.com/Cambridge-ICCS/FTorch/pull/373.
  dst->requires_grad_(src->requires_grad());
}

// Compute tensor1 + tensor2 and store the result in output.
// A plain assignment is used, so *output ends up referring to the newly
// computed tensor.
void torch_tensor_add(torch_tensor_t output, const torch_tensor_t tensor1,
                      const torch_tensor_t tensor2) {
  const auto &lhs = *reinterpret_cast<torch::Tensor *const>(tensor1);
  const auto &rhs = *reinterpret_cast<torch::Tensor *const>(tensor2);
  auto &result = *reinterpret_cast<torch::Tensor *>(output);
  result = lhs + rhs;
}

// Compute the elementwise negation of `tensor` and store it in output
// (plain assignment: *output refers to the new tensor afterwards).
void torch_tensor_negative(torch_tensor_t output, const torch_tensor_t tensor) {
  const auto &operand = *reinterpret_cast<torch::Tensor *const>(tensor);
  auto &result = *reinterpret_cast<torch::Tensor *>(output);
  result = -operand;
}

// Compute tensor1 - tensor2 and store the result in output
// (plain assignment: *output refers to the new tensor afterwards).
void torch_tensor_subtract(torch_tensor_t output, const torch_tensor_t tensor1,
                           const torch_tensor_t tensor2) {
  const auto &lhs = *reinterpret_cast<torch::Tensor *const>(tensor1);
  const auto &rhs = *reinterpret_cast<torch::Tensor *const>(tensor2);
  auto &result = *reinterpret_cast<torch::Tensor *>(output);
  result = lhs - rhs;
}

// Compute the elementwise product tensor1 * tensor2 and store it in output
// (plain assignment: *output refers to the new tensor afterwards).
void torch_tensor_multiply(torch_tensor_t output, const torch_tensor_t tensor1,
                           const torch_tensor_t tensor2) {
  const auto &lhs = *reinterpret_cast<torch::Tensor *const>(tensor1);
  const auto &rhs = *reinterpret_cast<torch::Tensor *const>(tensor2);
  auto &result = *reinterpret_cast<torch::Tensor *>(output);
  result = lhs * rhs;
}

// Compute the elementwise quotient tensor1 / tensor2 and store it in output
// (plain assignment: *output refers to the new tensor afterwards).
void torch_tensor_divide(torch_tensor_t output, const torch_tensor_t tensor1,
                         const torch_tensor_t tensor2) {
  const auto &numerator = *reinterpret_cast<torch::Tensor *const>(tensor1);
  const auto &denominator = *reinterpret_cast<torch::Tensor *const>(tensor2);
  auto &result = *reinterpret_cast<torch::Tensor *>(output);
  result = numerator / denominator;
}

// Raise every element of `tensor` to an integer power and store the result in
// output (plain assignment: *output refers to the new tensor afterwards).
// NOTE: `exponent` is reinterpreted as a pointer to an int, so this only works
// when the caller really passes the address of an int.
void torch_tensor_power_int(torch_tensor_t output, const torch_tensor_t tensor,
                            const torch_int_t exponent) {
  auto *result = reinterpret_cast<torch::Tensor *>(output);
  auto *base = reinterpret_cast<torch::Tensor *const>(tensor);
  const int power = *reinterpret_cast<int *const>(exponent);
  *result = pow(*base, power);
}

// Raise every element of `tensor` to a floating-point power and store the
// result in output (plain assignment: *output refers to the new tensor).
// NOTE: `exponent` is reinterpreted as a pointer to a (single-precision) float,
// so this only works when the caller really passes the address of a float.
void torch_tensor_power_float(torch_tensor_t output, const torch_tensor_t tensor,
                              const torch_float_t exponent) {
  auto *result = reinterpret_cast<torch::Tensor *>(output);
  auto *base = reinterpret_cast<torch::Tensor *const>(tensor);
  const float power = *reinterpret_cast<float *const>(exponent);
  *result = pow(*base, power);
}

// ============================================================================
// --- Other operators for computations involving tensors
// ============================================================================

// Sum all elements of `tensor` and write the result into `output`.
//
// output must be a rank-1 tensor of shape [1]; any other rank or shape is
// fatal (ctorch_error). The scalar sum is assigned through std::move(*dst),
// which writes into output's existing data rather than replacing the tensor.
void torch_tensor_sum(torch_tensor_t output, const torch_tensor_t tensor) {
  auto *dst = reinterpret_cast<torch::Tensor *>(output);
  auto *src = reinterpret_cast<torch::Tensor *const>(tensor);

  const int out_rank = torch_tensor_get_rank(output);
  if (out_rank != 1) {
    std::stringstream errmsg;
    errmsg << "Invalid rank of output tensor for sum\nrank=" << out_rank << " != 1";
    ctorch_error(errmsg.str());
  }
  // Only reached for rank-1 outputs, so the first extent exists.
  const auto out_length = torch_tensor_get_sizes(output)[0];
  if (out_length != 1) {
    std::stringstream errmsg;
    errmsg << "Invalid shape of output tensor for sum\nshape=[" << out_length
           << "] != [1]";
    ctorch_error(errmsg.str());
  }
  std::move(*dst) = src->sum();
}

// Compute the mean of all elements of `tensor` and write it into `output`.
//
// output must be a rank-1 tensor of shape [1]; any other rank or shape is
// fatal (ctorch_error). The scalar mean is assigned through std::move(*out),
// which writes into output's existing data rather than replacing the tensor.
void torch_tensor_mean(torch_tensor_t output, const torch_tensor_t tensor) {
  auto out = reinterpret_cast<torch::Tensor *>(output);
  auto t = reinterpret_cast<torch::Tensor *const>(tensor);

  if (torch_tensor_get_rank(output) != 1) {
    // Build the message in errmsg so that ctorch_error reports it. It was
    // previously streamed to std::cerr, which left ctorch_error with an empty
    // message.
    std::stringstream errmsg;
    errmsg << "Invalid rank of output tensor for mean\nrank="
           << torch_tensor_get_rank(output) << " != 1";
    ctorch_error(errmsg.str());
  }
  // Only reached for rank-1 outputs, so the first extent exists.
  if (torch_tensor_get_sizes(output)[0] != 1) {
    std::stringstream errmsg;
    errmsg << "Invalid shape of output tensor for mean\nshape=["
           << torch_tensor_get_sizes(output)[0] << "] != [1]";
    ctorch_error(errmsg.str());
  }
  std::move(*out) = t->mean();
}

// =============================================================================
// --- Functions related to automatic differentiation functionality for tensors
// =============================================================================

// Reset the accumulated gradient of `tensor` to zero, in place.
//
// The handle must be non-null, the tensor defined, and its gradient already
// defined (e.g. by an earlier backward call). Any failure is reported through
// ctorch_error, the same way torch_tensor_backward and
// torch_tensor_get_gradient do. An exception is therefore never propagated out
// through the C interface, where it would end in std::terminate.
void torch_tensor_zero_grad(torch_tensor_t tensor) {
  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  try {
    validate_tensor(t, "Gradient to zero");
    // zero_() on an undefined gradient fails inside libtorch, so report the
    // problem with a clearer message first.
    validate_gradient_defined(t, "Input tensor");
    t->mutable_grad().zero_();
  } catch (const std::exception &e) {
    ctorch_error(std::string(e.what()) + " in torch_tensor_zero_grad");
  }
}

// Run backpropagation from `tensor`, seeded with `external_gradient`.
//
// retain_graph is forwarded to Tensor::backward. Both tensors must be non-null
// and defined. Any failure (validation or libtorch) is reported through
// ctorch_error with " in torch_tensor_backward" appended.
void torch_tensor_backward(const torch_tensor_t tensor,
                           const torch_tensor_t external_gradient,
                           const bool retain_graph) {
  auto *root = reinterpret_cast<torch::Tensor *>(tensor);
  auto *seed = reinterpret_cast<torch::Tensor *const>(external_gradient);

  try {
    validate_tensor(root, "Input tensor");
    validate_tensor(seed, "External gradient");
    root->backward(*seed, retain_graph);
  } catch (const std::exception &err) {
    const std::string location = " in torch_tensor_backward";
    ctorch_error(err.what() + location);
  }
}

// Copy the gradient accumulated on `tensor` into the `gradient` tensor.
//
// `tensor` must be non-null, defined, require grad and have a defined
// gradient; `gradient` must be non-null. The copy goes through
// std::move(*dest), which writes into the existing gradient tensor. Any failure
// is reported through ctorch_error with " in torch_tensor_get_gradient"
// appended.
void torch_tensor_get_gradient(const torch_tensor_t tensor, torch_tensor_t gradient) {
  auto *source = reinterpret_cast<torch::Tensor *const>(tensor);
  auto *dest = reinterpret_cast<torch::Tensor *>(gradient);
  try {
    validate_tensor(source, "Input tensor");
    validate_tensor_not_null(dest, "Output gradient");
    validate_requires_grad(source, "Input tensor");
    validate_gradient_defined(source, "Input tensor");
    std::move(*dest) = source->grad();
  } catch (const std::exception &err) {
    ctorch_error(std::string(err.what()) + " in torch_tensor_get_gradient");
  }
}

// =============================================================================
// --- Torch model API
// =============================================================================

// Put a TorchScript module into training mode (is_training == true) or
// evaluation mode (the default).
void set_is_training(torch_jit_script_module_t module, const bool is_training = false) {
  auto *model = static_cast<torch::jit::script::Module *>(module);
  if (!is_training) {
    model->eval();
    return;
  }
  model->train();
}

// Load a TorchScript module from `filename` onto the requested device, then
// switch it to training or evaluation mode according to `is_training`.
//
// Gradient mode is set from requires_grad while loading. Returns a
// heap-allocated Module handle (torch_jit_module_delete frees it). Load
// failures are reported through ctorch_error, which frees the handle and
// terminates the program.
torch_jit_script_module_t torch_jit_load(const char *filename,
                                         const torch_device_t device_type = torch_kCPU,
                                         const int device_index = -1,
                                         const bool requires_grad = false,
                                         const bool is_training = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  torch::jit::script::Module *loaded = nullptr;
  const auto discard_loaded = [&]() { delete loaded; };
  try {
    loaded = new torch::jit::script::Module;
    const auto target_device = get_libtorch_device(device_type, device_index);
    *loaded = torch::jit::load(filename, target_device);
  } catch (const torch::Error &err) {
    ctorch_error(err.msg(), discard_loaded);
  } catch (const std::exception &err) {
    ctorch_error(err.what(), discard_loaded);
  }
  set_is_training(loaded, is_training);

  return loaded;
}

void torch_jit_module_forward(const torch_jit_script_module_t module,
                              const torch_tensor_t *inputs, const int nin,
                              torch_tensor_t *outputs, const int nout,
                              const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  // Here we cast the pointers we recieved in to Tensor objects
  auto model = static_cast<torch::jit::script::Module *>(module);
  auto in = reinterpret_cast<torch::Tensor *const *>(inputs);
  auto out = reinterpret_cast<torch::Tensor **>(outputs);
  // Local IValue for checking we are passed types
  torch::jit::IValue LocalTensor;
  // Generate a vector of IValues (placeholders for various Torch types)
  std::vector<torch::jit::IValue> inputs_vec;
  // Populate with Tensors pointed at by pointers
  // For each IValue check it is of Tensor type
  for (int i = 0; i < nin; ++i) {
    LocalTensor = *(in[i]);
    if (LocalTensor.isTensor()) {
      inputs_vec.push_back(LocalTensor);
    } else {
      ctorch_error("One of the inputs to torch_jit_module_forward is not a Tensor");
    }
  }
  try {
    auto model_out = model->forward(inputs_vec);
    if (model_out.isTensor()) {
      // Single output models will return a tensor directly.
      std::move(*out[0]) = model_out.toTensor();
    } else if (model_out.isTuple()) {
      // Multiple output models will return a tuple => cast to tensors.
      for (int i = 0; i < nout; ++i) {
        std::move(*out[i]) = model_out.toTuple()->elements()[i].toTensor();
      }
    } else {
      // If for some reason the forward method does not return a Tensor it
      // should raise an error when trying to cast to a Tensor type
      ctorch_error("Model Output is neither Tensor nor Tuple");
    }
  } catch (const torch::Error &e) {
    ctorch_error(e.msg());
  } catch (const std::exception &e) {
    ctorch_error(e.what());
  }
}

// Destroy a TorchScript module handle allocated with `new` (as torch_jit_load
// does). Passing a null handle is a no-op.
void torch_jit_module_delete(torch_jit_script_module_t module) {
  delete reinterpret_cast<torch::jit::script::Module *>(module);
}