Overloads subtraction operator for two tensors.
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| type(torch_tensor), | intent(in) | | | :: | tensor1 | |
| type(torch_tensor), | intent(in) | | | :: | tensor2 | |
!> Overloads subtraction operator for two tensors.
!!
!! Both operands must report the same device type; otherwise a message is
!! printed and execution stops with code 1. The tensor handles are passed to
!! the C routine bound to 'torch_tensor_subtract', which receives `output`
!! as its first argument.
function torch_tensor_subtract(tensor1, tensor2) result(output)
  use, intrinsic :: iso_c_binding, only : c_associated

  type(torch_tensor), intent(in) :: tensor1  !! First operand
  type(torch_tensor), intent(in) :: tensor2  !! Second operand
  type(torch_tensor) :: output               !! Tensor handed to the C routine as its output

  ! Explicit interface for the C symbol that is called at the end of this function.
  interface
    subroutine torch_tensor_subtract_c(result_ptr, lhs_ptr, rhs_ptr) &
        bind(c, name = 'torch_tensor_subtract')
      use, intrinsic :: iso_c_binding, only : c_ptr
      implicit none
      type(c_ptr), value, intent(in) :: result_ptr
      type(c_ptr), value, intent(in) :: lhs_ptr
      type(c_ptr), value, intent(in) :: rhs_ptr
    end subroutine torch_tensor_subtract_c
  end interface

  ! Refuse to combine operands whose device types differ.
  if (tensor1%get_device_type() /= tensor2%get_device_type()) then
    write(*,*) "Error :: cannot subtract tensors with different device types"
    stop 1
  end if

  ! When the result has no associated C pointer yet, set it up via
  ! torch_tensor_empty using tensor1's rank, shape, dtype, device type,
  ! device index and requires_grad setting.
  if (.not. c_associated(output%p)) &
    call torch_tensor_empty(output, tensor1%get_rank(), tensor1%get_shape(), &
                            tensor1%get_dtype(), tensor1%get_device_type(), &
                            device_index=tensor1%get_device_index(), &
                            requires_grad=tensor1%requires_grad())

  call torch_tensor_subtract_c(output%p, tensor1%p, tensor2%p)
end function torch_tensor_subtract