torch_tensor_get_gradient Subroutine

public subroutine torch_tensor_get_gradient(tensor, gradient)

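Retrieves the gradient with respect to a Torch Tensor. The tensor that receives the gradient must already be constructed; otherwise execution stops with an error.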

Arguments

Type Intent Optional Attributes Name
class(torch_tensor), intent(in) :: tensor

Tensor to compute the gradient with respect to

type(torch_tensor), intent(inout) :: gradient

Tensor holding the gradient


Source Code

  !> Retrieves the gradient with respect to a Torch Tensor.
  !!
  !! Both `tensor` and `gradient` must already wrap constructed tensors.
  !! The C routine receives the `gradient` handle by value, so it can only fill
  !! the existing tensor, not replace the Fortran-side handle.
  !! If either handle is null, an error is written to standard error and
  !! execution stops with exit status 1.
  subroutine torch_tensor_get_gradient(tensor, gradient)
    use, intrinsic :: iso_fortran_env, only : error_unit
    class(torch_tensor), intent(in) :: tensor      !! Tensor to compute the gradient with respect to
    type(torch_tensor), intent(inout) :: gradient  !! Tensor holding the gradient

    interface
      subroutine torch_tensor_get_gradient_c(tensor_c, gradient_c) &
          bind(c, name = 'torch_tensor_get_gradient')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: tensor_c
        type(c_ptr), value, intent(in) :: gradient_c
      end subroutine torch_tensor_get_gradient_c
    end interface

    ! The destination tensor is filled in place, so it must exist beforehand.
    if (.not. c_associated(gradient%p)) then
      write(error_unit, *) "Error :: tensors for holding gradients must be constructed before retrieving values"
      stop 1
    end if

    ! Also reject a null source handle here. Otherwise it would cross the C
    ! boundary unchecked; the C side presumably does not validate it (NOTE(review): confirm).
    if (.not. c_associated(tensor%p)) then
      write(error_unit, *) "Error :: tensor must be constructed before retrieving its gradient"
      stop 1
    end if

    call torch_tensor_get_gradient_c(tensor%p, gradient%p)
  end subroutine torch_tensor_get_gradient