torch_model_parameters Subroutine

public subroutine torch_model_parameters(model, output_tensors)

Uses

  • iso_c_binding (intrinsic module)

Extracts the parameters from a model into an array of tensors.

Arguments

Type Intent Optional Attributes Name
type(torch_model), intent(in) :: model

Model

type(torch_tensor), intent(inout), dimension(:) :: output_tensors

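Returned output tensors. No element may already be associated with a tensor on entry; otherwise the routine stops with an error.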


Source Code

  subroutine torch_model_parameters(model, output_tensors)
    !! Extracts the parameters from a model into an array of tensors.
    !!
    !! Every element of `output_tensors` must be unassociated on entry. Its
    !! existing handle would otherwise be overwritten and lost, so the routine
    !! stops with an error instead. On return, element i holds the tensor handle
    !! that the C backend wrote into slot i.
    use, intrinsic :: iso_c_binding, only : c_associated, c_int, c_loc, c_null_ptr, c_ptr
    type(torch_model), intent(in) :: model  !! Model
    type(torch_tensor), intent(inout), dimension(:) :: output_tensors  !! Returned output tensors

    integer(c_int) :: n_outputs
    integer :: i
    type(c_ptr), dimension(size(output_tensors)), target :: output_ptrs
    type(c_ptr) :: output_ptrs_c  ! address of output_ptrs as seen by C (or null)

    interface
      subroutine torch_jit_model_parameters_c(model_c, output_tensors_c, n_outputs_c) &
          bind(c, name = 'torch_jit_module_parameters')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int
        implicit none
        type(c_ptr), value, intent(in) :: model_c
        type(c_ptr), value, intent(in) :: output_tensors_c
        integer(c_int), value, intent(in) :: n_outputs_c
      end subroutine torch_jit_model_parameters_c
    end interface

    n_outputs = size(output_tensors)

    ! Refuse to overwrite tensors that already hold a handle: the returned
    ! pointers replace output_tensors(:)%p wholesale below.
    do i = 1, n_outputs
      if (c_associated(output_tensors(i)%p)) then
        write(*,*) "Error :: tensor pointer will be lost in call to torch_model_parameters"
        stop 1
      end if
    end do

    ! The standard forbids c_loc on a zero-size array, so hand C a null
    ! pointer when no output slots were supplied. The C routine still receives
    ! n_outputs (= 0) and is called in both cases.
    if (n_outputs > 0) then
      output_ptrs_c = c_loc(output_ptrs)
    else
      output_ptrs_c = c_null_ptr
    end if

    ! Get the parameters that were created during model construction
    call torch_jit_model_parameters_c(model%p, output_ptrs_c, n_outputs)

    ! Copy the handles filled in by C back into the Fortran tensor wrappers
    output_tensors(:)%p = output_ptrs
  end subroutine torch_model_parameters