Loads a TorchScript `nn.Module` (a pre-trained PyTorch model saved with TorchScript)
Type | Intent | Optional | Attributes | Name | ||
---|---|---|---|---|---|---|
type(torch_model), | intent(out) | :: | model |
Returned deserialized model |
||
character(len=*), | intent(in) | :: | filename |
Filename of saved TorchScript model |
||
integer(kind=c_int), | intent(in), | optional | :: | device_type |
Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) |
|
integer(kind=c_int), | intent(in), | optional | :: | device_index |
Device index to use for the `torch_kCUDA` case |
|
logical, | intent(in), | optional | :: | requires_grad |
Whether gradients need to be computed for the created tensor |
|
logical, | intent(in), | optional | :: | is_training |
Whether the model is being trained, rather than evaluated |
subroutine torch_model_load(model, filename, device_type, device_index, &
                            requires_grad, is_training)
  !! Loads a TorchScript nn.Module (pre-trained PyTorch model saved with TorchScript).
  !!
  !! Optional arguments are resolved to concrete values before calling the C
  !! routine `torch_jit_load`:
  !!   device_type   -> torch_kCPU
  !!   device_index  -> -1 when device_type is torch_kCPU, otherwise 0
  !!   requires_grad -> .false.
  !!   is_training   -> .false.
  use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_null_char
  type(torch_model), intent(out) :: model
    !! Returned deserialized model
  character(*), intent(in) :: filename
    !! Filename of saved TorchScript model
  integer(c_int), optional, intent(in) :: device_type
    !! Device type to load the model onto (`torch_kCPU` or `torch_kCUDA`); default `torch_kCPU`
  integer(c_int), optional, intent(in) :: device_index
    !! Device index to use for the `torch_kCUDA` case; default -1 for `torch_kCPU`, 0 otherwise
  logical, optional, intent(in) :: requires_grad
    !! Whether gradients need to be computed (passed through to `torch_jit_load`); default `.false.`
  logical, optional, intent(in) :: is_training
    !! Whether the model is being trained, rather than evaluated; default `.false.`

  integer(c_int) :: device_type_value
  integer(c_int) :: device_index_value
  logical :: requires_grad_value   !! Resolved value of `requires_grad`
  logical :: is_training_value     !! Resolved value of `is_training`

  interface
    function torch_jit_load_c(filename, device_type, device_index, &
                              requires_grad, is_training) result(model) &
        bind(c, name = 'torch_jit_load')
      use, intrinsic :: iso_c_binding, only : c_bool, c_char, c_int, c_ptr
      ! Interface bodies do not inherit the host's implicit-typing rules.
      implicit none
      character(c_char), intent(in) :: filename(*)
      integer(c_int), value, intent(in) :: device_type
      integer(c_int), value, intent(in) :: device_index
      logical(c_bool), value, intent(in) :: requires_grad
      logical(c_bool), value, intent(in) :: is_training
      type(c_ptr) :: model
    end function torch_jit_load_c
  end interface

  ! Process optional arguments
  if (present(device_type)) then
    device_type_value = device_type
  else
    device_type_value = torch_kCPU
  end if

  ! A CPU device gets index -1; any other device defaults to index 0
  if (present(device_index)) then
    device_index_value = device_index
  else if (device_type_value == torch_kCPU) then
    device_index_value = -1
  else
    device_index_value = 0
  end if

  if (present(requires_grad)) then
    requires_grad_value = requires_grad
  else
    requires_grad_value = .false.
  end if

  if (present(is_training)) then
    is_training_value = is_training
  else
    is_training_value = .false.
  end if

  ! The C side expects a NUL-terminated string, so append c_null_char
  model%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char, &
                             device_type_value, device_index_value, &
                             logical(requires_grad_value, c_bool), &
                             logical(is_training_value, c_bool))
end subroutine torch_model_load