torch_tensor_backward Subroutine

public subroutine torch_tensor_backward(tensor, retain_graph)

Performs back-propagation on a Torch Tensor, using a tensor of ones (matching the input tensor) as the external gradient.

Arguments

Type Intent Optional Attributes Name
type(torch_tensor), intent(in) :: tensor

Tensor to compute gradients of

logical, intent(in), optional :: retain_graph

Should the computational graph be retained?


Source Code

  subroutine torch_tensor_backward(tensor, retain_graph)
    !! Run back-propagation starting from `tensor`. The gradient calculation is
    !! seeded with a tensor of ones created with the same rank, shape, dtype,
    !! device type and device index as `tensor`.
    use, intrinsic :: iso_c_binding, only : c_bool
    type(torch_tensor), intent(in) :: tensor       !! Tensor whose gradients are to be computed
    logical, intent(in), optional  :: retain_graph !! Keep the computational graph after the call?

    ! Local variables
    logical(kind=c_bool) :: keep_graph   !! C-interoperable copy of the retain flag
    type(torch_tensor)   :: seed         !! Tensor of ones passed as the external gradient

    ! Explicit interface to the C routine that performs the back-propagation
    interface
      subroutine torch_tensor_backward_c(tensor_c, external_gradient_c, &
                                         retain_graph_c) &
          bind(c, name = 'torch_tensor_backward')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr
        implicit none
        type(c_ptr), value, intent(in)     :: tensor_c
        type(c_ptr), value, intent(in)     :: external_gradient_c
        logical(c_bool), value, intent(in) :: retain_graph_c
      end subroutine torch_tensor_backward_c
    end interface

    ! The graph is not retained unless the caller supplies a flag saying otherwise
    keep_graph = .false.
    if (present(retain_graph)) keep_graph = retain_graph

    ! Build the seed gradient: ones, mirroring the properties of the input tensor
    ! TODO: Accept other external gradients as an optional argument
    call torch_tensor_ones(seed, tensor%get_rank(), tensor%get_shape(), &
                           tensor%get_dtype(), tensor%get_device_type(), &
                           device_index=tensor%get_device_index())

    ! Hand the raw pointers and the flag over to the C back-propagation routine
    call torch_tensor_backward_c(tensor%p, seed%p, keep_graph)

    ! The seed tensor is only needed for this call, so release it again
    call torch_tensor_delete(seed)
  end subroutine torch_tensor_backward