Moves a source tensor to the device and dtype of a target tensor.
| Type | Intent | Optional | Attributes | | Name | Description |
|---|---|---|---|---|---|---|
| type(torch_tensor), | intent(in) | | | :: | source_tensor | Source tensor to be moved |
| type(torch_tensor), | intent(inout) | | | :: | target_tensor | Target tensor with the desired device and dtype |
| logical, | intent(in), | optional | | :: | non_blocking | Whether to perform asynchronous copy |
subroutine torch_tensor_to(source_tensor, target_tensor, non_blocking)
  !! Moves a source tensor to a target tensor's device and dtype.
  !!
  !! Before the transfer is delegated to the C routine bound as `torch_tensor_to`,
  !! the ranks of both tensors and then the extent of every dimension are compared.
  !! On the first disagreement a diagnostic is written to the default output unit
  !! and execution halts with stop code 1.
  use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t
  type(torch_tensor), intent(in) :: source_tensor
    !! Source tensor to be moved
  type(torch_tensor), intent(inout) :: target_tensor
    !! Target tensor with the desired device and dtype
  logical, optional, intent(in) :: non_blocking
    !! Whether to perform asynchronous copy (treated as `.false.` when absent)

  ! C-interoperable copy of the optional flag that is handed to the C routine
  logical(c_bool) :: async_flag
  ! Ranks of the two tensors and the dimension counter used for the shape check
  integer(c_int) :: rank_src, rank_dst, idim
  ! Shapes of the two tensors as returned by their get_shape bindings
  integer(c_int64_t), pointer :: extents_src(:), extents_dst(:)

  interface
    subroutine torch_tensor_to_c(source_tensor_c, target_tensor_c, non_blocking_c) &
        bind(c, name = 'torch_tensor_to')
      use, intrinsic :: iso_c_binding, only : c_bool, c_ptr
      implicit none
      type(c_ptr), value, intent(in) :: source_tensor_c
      type(c_ptr), value, intent(in) :: target_tensor_c
      logical(c_bool), value, intent(in) :: non_blocking_c
    end subroutine torch_tensor_to_c
  end interface

  ! Resolve the optional argument: start from the default, override when supplied
  async_flag = .false.
  if (present(non_blocking)) async_flag = non_blocking

  ! The ranks have to agree before the shapes can be compared dimension by dimension
  rank_src = source_tensor%get_rank()
  rank_dst = target_tensor%get_rank()
  if (rank_src /= rank_dst) then
    write(*,*) "Error in torch_tensor_to :: Cannot move source_tensor to target_tensor because the ranks do not match."
    write(*,*) "Source tensor rank:", rank_src, "Target tensor rank:", rank_dst
    stop 1
  end if

  ! Walk the dimensions in order and abort on the first extent that differs
  extents_src => source_tensor%get_shape()
  extents_dst => target_tensor%get_shape()
  do idim = 1, rank_src
    if (extents_src(idim) == extents_dst(idim)) cycle
    write(*,*) "Error in torch_tensor_to :: Cannot move source_tensor to target_tensor because the shapes do not match."
    write(*,*) "Dimension", idim, "mismatch: source_tensor =", extents_src(idim), &
               "Target =", extents_dst(idim)
    stop 1
  end do

  ! All checks passed: hand the underlying tensor handles to the C implementation
  call torch_tensor_to_c(source_tensor%p, target_tensor%p, async_flag)
end subroutine torch_tensor_to