torch_tensor_backward Subroutine

public subroutine torch_tensor_backward(tensor, retain_graph)

Uses

  • torch_tensor_backward uses the intrinsic module iso_c_binding.

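Performs back-propagation on a Torch tensor. The backward pass is seeded with an external gradient that is always a tensor of ones, created with the same rank, shape, dtype, device type and device index as the input tensor.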

Arguments

Type | Intent | Optional | Attributes | Name
type(torch_tensor), intent(in) :: tensor

Tensor to compute gradients of

logical, intent(in), optional :: retain_graph

Should the computational graph be retained?


Calls

torch_tensor_backward calls:
  • torch_tensor_ones
  • torch_tensor_delete
  • torch_tensor%torch_tensor_get_rank
  • torch_tensor%torch_tensor_get_shape (which itself calls torch_tensor_get_rank)
  • torch_tensor%torch_tensor_get_dtype
  • torch_tensor%torch_tensor_get_device_type
  • torch_tensor%torch_tensor_get_device_index

Source Code

  subroutine torch_tensor_backward(tensor, retain_graph)
    !! Run back-propagation from `tensor`, seeding the backward pass with a
    !! tensor of ones that matches `tensor`'s rank, shape, dtype and device.
    !! The temporary seed tensor is deleted before returning.
    use, intrinsic :: iso_c_binding, only : c_bool
    type(torch_tensor), intent(in) :: tensor       !! Tensor whose gradients are to be computed
    logical, optional, intent(in)  :: retain_graph !! Keep the computational graph after the backward pass? Defaults to .false.

    ! Local variables
    logical(c_bool)    :: keep_graph  !! C-interoperable copy of the retain_graph flag handed to the backend
    type(torch_tensor) :: seed_grad   !! Tensor of ones used as the initial (external) gradient

    interface
      subroutine torch_tensor_backward_c(tensor_c, external_gradient_c, retain_graph_c) &
          bind(c, name = 'torch_tensor_backward')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: tensor_c
        type(c_ptr), value, intent(in) :: external_gradient_c
        logical(c_bool), value, intent(in) :: retain_graph_c
      end subroutine torch_tensor_backward_c
    end interface

    ! The graph is discarded unless the caller explicitly asks to keep it.
    ! Assigned at run time rather than at declaration to avoid an implicit `save`.
    keep_graph = .false._c_bool
    if (present(retain_graph)) keep_graph = logical(retain_graph, kind=c_bool)

    ! Build the seed gradient: a ones tensor mirroring the input's rank, shape,
    ! dtype, device type and device index.
    ! TODO: Accept other external gradients as an optional argument
    call torch_tensor_ones(seed_grad, tensor%get_rank(), tensor%get_shape(), &
                           tensor%get_dtype(), tensor%get_device_type(), &
                           device_index=tensor%get_device_index())

    ! Hand the raw handles to the C backend to perform the backward pass.
    call torch_tensor_backward_c(tensor%p, seed_grad%p, keep_graph)

    ! The seed tensor was only needed for the call above; free it.
    call torch_tensor_delete(seed_grad)
  end subroutine torch_tensor_backward