torch_tensor_to Subroutine

public subroutine torch_tensor_to(source_tensor, target_tensor, non_blocking)

Moves a source tensor to the device and dtype of a target tensor. The two tensors must have the same rank and shape.

Arguments

type(torch_tensor), intent(in) :: source_tensor
    Source tensor to be moved

type(torch_tensor), intent(inout) :: target_tensor
    Target tensor with the desired device and dtype

logical, intent(in), optional :: non_blocking
    Whether to perform an asynchronous copy (default: .false.)
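
Example

A minimal sketch of typical usage: wrap a Fortran array as a CPU tensor, allocate an empty tensor of the same rank and shape on a GPU, and move the data across. It assumes the FTorch constructors torch_tensor_from_array and torch_tensor_empty and the torch_delete cleanup routine, with signatures as used in the FTorch examples; these may differ between FTorch versions:

  program tensor_to_example
    use, intrinsic :: iso_c_binding, only : c_int64_t
    use, intrinsic :: iso_fortran_env, only : real32
    use ftorch, only : torch_tensor, torch_tensor_from_array, torch_tensor_empty, &
                       torch_tensor_to, torch_delete, torch_kFloat32, torch_kCPU, torch_kCUDA
    implicit none

    real(real32), target :: in_data(2, 3)
    integer :: layout(2) = [1, 2]
    integer(c_int64_t) :: tensor_shape(2) = [2_c_int64_t, 3_c_int64_t]
    type(torch_tensor) :: cpu_tensor, gpu_tensor

    in_data = 1.0_real32

    ! Wrap the Fortran array as a CPU tensor
    call torch_tensor_from_array(cpu_tensor, in_data, layout, torch_kCPU)

    ! Allocate an uninitialised tensor of the same rank and shape on the GPU;
    ! its device and dtype determine where the data will be moved
    call torch_tensor_empty(gpu_tensor, 2, tensor_shape, torch_kFloat32, torch_kCUDA)

    ! Move the data; non_blocking defaults to .false. (synchronous copy)
    call torch_tensor_to(cpu_tensor, gpu_tensor)

    call torch_delete(cpu_tensor)
    call torch_delete(gpu_tensor)
  end program tensor_to_example

Passing non_blocking=.true. requests an asynchronous copy, mirroring the semantics of PyTorch's Tensor.to: the call may return before the copy has completed, so synchronise the device before reading from the target tensor.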


Source Code

  subroutine torch_tensor_to(source_tensor, target_tensor, non_blocking)
    use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t
    type(torch_tensor), intent(in) :: source_tensor      !! Source tensor to be moved
    type(torch_tensor), intent(inout) :: target_tensor   !! Target tensor with the desired device and dtype
    logical, optional, intent(in) :: non_blocking        !! Whether to perform an asynchronous copy
    logical(c_bool) :: non_blocking_value
    integer(c_int) :: source_rank, target_rank, i
    integer(c_int64_t), pointer :: source_shape(:), target_shape(:)

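    ! C binding to the torch_tensor_to implementation in FTorch's C wrapper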
    interface
      subroutine torch_tensor_to_c(source_tensor_c, target_tensor_c, non_blocking_c) &
          bind(c, name = 'torch_tensor_to')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: source_tensor_c
        type(c_ptr), value, intent(in) :: target_tensor_c
        logical(c_bool), value, intent(in) :: non_blocking_c
      end subroutine torch_tensor_to_c
    end interface

    ! Check for rank and shape consistency between the source and target tensors
    source_rank = source_tensor%get_rank()
    target_rank = target_tensor%get_rank()

    if (source_rank /= target_rank) then
      write(*,*) "Error in torch_tensor_to :: Cannot move source_tensor to target_tensor because the ranks do not match."
      write(*,*) "Source tensor rank:", source_rank, "Target tensor rank:", target_rank
      stop 1
    end if

    source_shape => source_tensor%get_shape()
    target_shape => target_tensor%get_shape()

    do i = 1, source_rank
      if (source_shape(i) /= target_shape(i)) then
        write(*,*) "Error in torch_tensor_to :: Cannot move source_tensor to target_tensor because the shapes do not match."
        write(*,*) "Dimension", i, "mismatch: source_tensor =", source_shape(i), &
            "Target =", target_shape(i)
        stop 1
      end if
    end do

    ! Process optional arguments
    if (present(non_blocking)) then
      non_blocking_value = non_blocking
    else
      non_blocking_value = .false.
    end if

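    ! Delegate the device/dtype move to the Torch backend via the C interface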
    call torch_tensor_to_c(source_tensor%p, target_tensor%p, non_blocking_value)

  end subroutine torch_tensor_to