Performs back-propagation on a Torch Tensor, given some external gradient.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| type(torch_tensor) | intent(in) | — | | tensor | Tensor to compute gradients of |
| logical | intent(in) | optional | | retain_graph | Should the computational graph be retained? |
!> Performs back-propagation on a Torch Tensor, given some external gradient.
subroutine torch_tensor_backward(tensor, retain_graph)
  use, intrinsic :: iso_c_binding, only : c_bool
  type(torch_tensor), intent(in) :: tensor
    !! Tensor to compute gradients of
  logical, optional, intent(in) :: retain_graph
    !! Should the computational graph be retained?

  ! Local arguments
  type(torch_tensor) :: ones_gradient
    !! External tensor used as an initial scaling of the gradient calculation
  logical(c_bool) :: retain_graph_flag
    !! Value forwarded to the C layer once the optional argument is resolved

  interface
    subroutine torch_tensor_backward_c(tensor_c, external_gradient_c, retain_graph_c) &
        bind(c, name = 'torch_tensor_backward')
      use, intrinsic :: iso_c_binding, only : c_bool, c_ptr
      implicit none
      type(c_ptr), value, intent(in) :: tensor_c
      type(c_ptr), value, intent(in) :: external_gradient_c
      logical(c_bool), value, intent(in) :: retain_graph_c
    end subroutine torch_tensor_backward_c
  end interface

  ! External gradient to provide to the back-propagation consisting of a tensor of ones
  ! TODO: Accept other external gradients as an optional argument
  call torch_tensor_ones(ones_gradient, tensor%get_rank(), tensor%get_shape(), &
                         tensor%get_dtype(), tensor%get_device_type(), &
                         device_index=tensor%get_device_index())

  ! Default is to free the computational graph; retain it only on explicit request
  retain_graph_flag = .false._c_bool
  if (present(retain_graph)) retain_graph_flag = logical(retain_graph, c_bool)

  ! Call back-propagation with the provided external gradient
  call torch_tensor_backward_c(tensor%p, ones_gradient%p, retain_graph_flag)

  ! Delete the external gradient tensor
  call torch_tensor_delete(ones_gradient)
end subroutine torch_tensor_backward