torch_model_load Subroutine

public subroutine torch_model_load(model, filename, device_type, device_index, requires_grad, is_training)

Loads a TorchScript nn.Module (a pre-trained PyTorch model saved with TorchScript).

Arguments

Type Intent Optional Attributes Name
type(torch_model), intent(out) :: model

Returned deserialized model

character(len=*), intent(in) :: filename

Filename of saved TorchScript model

integer(kind=c_int), intent(in), optional :: device_type

Device type the model will live on (torch_kCPU or torch_kCUDA); defaults to torch_kCPU

integer(kind=c_int), intent(in), optional :: device_index

Device index to use for the torch_kCUDA case; defaults to 0 (or -1 when the device type is torch_kCPU)

logical, intent(in), optional :: requires_grad

Whether gradients need to be computed; defaults to .false.

logical, intent(in), optional :: is_training

Whether the model is being trained, rather than evaluated; defaults to .false.


Source Code

  subroutine torch_model_load(model, filename, device_type, device_index, &
                              requires_grad, is_training)
    !! Loads a TorchScript nn.Module (a pre-trained PyTorch model saved with TorchScript).
    !!
    !! Absent optional arguments take these defaults: device type `torch_kCPU`;
    !! device index -1 for `torch_kCPU`, otherwise 0; `requires_grad` and
    !! `is_training` both `.false.`. The filename has leading and trailing blanks
    !! removed and is NUL-terminated before being handed to the C routine
    !! `torch_jit_load`.
    use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_null_char
    implicit none
    type(torch_model), intent(out)       :: model          !! Returned deserialized model
    character(*), intent(in)             :: filename       !! Filename of saved TorchScript model
    integer(c_int), optional, intent(in) :: device_type    !! Device type the model will live on (`torch_kCPU` or `torch_kCUDA`)
    integer(c_int), optional, intent(in) :: device_index   !! Device index to use for the `torch_kCUDA` case
    logical, optional, intent(in)        :: requires_grad  !! Whether gradients need to be computed (default `.false.`)
    logical, optional, intent(in)        :: is_training    !! Whether the model is being trained, rather than evaluated (default `.false.`)
    integer(c_int) :: device_type_value   ! device type after applying the default
    integer(c_int) :: device_index_value  ! device index after applying the default
    logical :: requires_grad_value  !! Whether gradients need to be computed
    logical :: is_training_value    !! Whether the model is being trained, rather than evaluated

    interface
      function torch_jit_load_c(filename, device_type, device_index, &
                                requires_grad, is_training) result(model) &
          bind(c, name = 'torch_jit_load')
        use, intrinsic :: iso_c_binding, only : c_bool, c_char, c_int, c_ptr
        implicit none
        character(c_char), intent(in) :: filename(*)
        integer(c_int), value, intent(in)    :: device_type
        integer(c_int), value, intent(in)    :: device_index
        logical(c_bool), value, intent(in) :: requires_grad
        logical(c_bool), value, intent(in) :: is_training
        type(c_ptr)                   :: model
      end function torch_jit_load_c
    end interface

    ! Process optional arguments, substituting defaults for absent ones.
    ! (merge() cannot be used here: it may reference an absent optional.)
    if (present(device_type)) then
      device_type_value = device_type
    else
      device_type_value = torch_kCPU
    end if

    ! A CPU model gets index -1; any other device defaults to index 0.
    if (present(device_index)) then
      device_index_value = device_index
    else if (device_type_value == torch_kCPU) then
      device_index_value = -1
    else
      device_index_value = 0
    end if

    if (present(requires_grad)) then
      requires_grad_value = requires_grad
    else
      requires_grad_value = .false.
    end if

    if (present(is_training)) then
      is_training_value = is_training
    else
      is_training_value = .false.
    end if

    ! The C side expects a NUL-terminated string, so append c_null_char after
    ! stripping surrounding blanks. The default-kind logicals are converted to
    ! c_bool to match the bind(c) interface.
    model%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char,           &
                               device_type_value, device_index_value,         &
                               logical(requires_grad_value, c_bool),          &
                               logical(is_training_value, c_bool))
  end subroutine torch_model_load