get_shape Function

public function get_shape(self) result(sizes)

Determines the shape of a tensor.

Type Bound

torch_tensor

Arguments

Type Intent Optional Attributes Name
class(torch_tensor), intent(in) :: self

Return Value integer(kind=c_long), pointer, (:)

Pointer to the tensor's sizes (the extent of each dimension)


Source Code

  !> Determines the shape of a tensor.
  !!
  !! Obtains the address of the tensor's size array from the C routine
  !! `torch_tensor_get_sizes` and associates a rank-1 Fortran pointer with it.
  !! The pointer has one element per tensor dimension (as given by `get_rank`).
  !! No copy of the sizes is made by this function.
  !!
  !! NOTE(review): the memory behind the returned pointer is presumably owned by
  !! the underlying tensor on the C/C++ side - confirm its lifetime before holding
  !! on to the pointer, and do not deallocate it from Fortran.
  function get_shape(self) result(sizes)
    use, intrinsic :: iso_c_binding, only : c_f_pointer, c_long, c_ptr
    use, intrinsic :: iso_fortran_env, only : int32
    class(torch_tensor), intent(in) :: self   !! Tensor whose shape is queried
    integer(kind=c_long), pointer :: sizes(:) !! Pointer to tensor sizes, one extent per dimension
    integer(kind=int32) :: ndims(1)           ! Shape argument for c_f_pointer: [rank of tensor]
    type(c_ptr) :: cptr                       ! Address of the size array returned from C

    interface
      function torch_tensor_get_sizes_c(tensor) result(sizes) &
          bind(c, name = 'torch_tensor_get_sizes')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: tensor
        type(c_ptr) :: sizes
      end function torch_tensor_get_sizes_c
    end interface

    ! The size array holds one entry per dimension, so its length is the rank.
    ndims(1) = self%get_rank()

    ! Fetch the C address of the size array and view it as a Fortran array.
    cptr = torch_tensor_get_sizes_c(self%p)
    call c_f_pointer(cptr, sizes, ndims)
  end function get_shape