Create an AdamW optimizer
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| type(torch_optim) | intent(out) | | | optim | Optimizer we are creating |
| type(torch_tensor) | intent(in) | | dimension(:) | parameters | Array of parameter tensors |
| real(kind=real64) | intent(in) | optional | | learning_rate | learning rate for the optimization algorithm (default: 0.001) |
| real(kind=real64) | intent(in) | optional | | beta_1 | beta 1 for the optimization algorithm (default: 0.9) |
| real(kind=real64) | intent(in) | optional | | beta_2 | beta 2 for the optimization algorithm (default: 0.999) |
| real(kind=real64) | intent(in) | optional | | eps | eps for the optimization algorithm (default: 1.0e-8) |
| real(kind=real64) | intent(in) | optional | | weight_decay | weight_decay for the optimization algorithm (default: 0.01) |
| logical | intent(in) | optional | | amsgrad | enable AMSGrad variant (default: .false.) |
!> Create an AdamW optimizer acting on the given array of parameter tensors.
!! Optional hyperparameters default to the standard PyTorch AdamW defaults.
subroutine torch_optim_AdamW(optim, parameters, learning_rate, beta_1, beta_2, &
                             eps, weight_decay, amsgrad)
  use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
  use, intrinsic :: iso_fortran_env, only : real64
  type(torch_optim), intent(out) :: optim
  !! Optimizer we are creating
  type(torch_tensor), intent(in), dimension(:) :: parameters
  !! Array of parameter tensors
  real(kind=real64), optional, intent(in) :: learning_rate
  !! learning rate for the optimization algorithm (default: 0.001)
  real(kind=real64), optional, intent(in) :: beta_1
  !! beta 1 for the optimization algorithm (default: 0.9)
  real(kind=real64), optional, intent(in) :: beta_2
  !! beta 2 for the optimization algorithm (default: 0.999)
  real(kind=real64), optional, intent(in) :: eps
  !! eps for the optimization algorithm (default: 1.0e-8)
  real(kind=real64), optional, intent(in) :: weight_decay
  !! weight_decay for the optimization algorithm (default: 0.01)
  logical, optional, intent(in) :: amsgrad
  !! enable AMSGrad variant (default: .false.)

  real(kind=real64) :: learning_rate_value
  !! Resolved learning_rate value to be passed to the C interface
  real(kind=real64) :: beta_1_value
  !! Resolved beta_1 value to be passed to the C interface
  real(kind=real64) :: beta_2_value
  !! Resolved beta_2 value to be passed to the C interface
  real(kind=real64) :: eps_value
  !! Resolved eps value to be passed to the C interface
  real(kind=real64) :: weight_decay_value
  !! Resolved weight_decay value to be passed to the C interface
  logical :: amsgrad_value
  !! Resolved amsgrad value to be passed to the C interface
  integer(ftorch_int) :: i
  integer(c_int) :: n_params
  type(c_ptr), dimension(size(parameters)), target :: parameter_ptrs

  interface
    ! NOTE(review): the dummy-argument order here previously listed
    ! weight_decay_c BEFORE eps_c, while the call below passes eps before
    ! weight_decay. Binding to C is positional, so one of the two was wrong.
    ! The call order matches PyTorch's AdamWOptions (betas, eps, weight_decay)
    ! and carries the deliberate AdamW-specific weight_decay default, so the
    ! interface has been aligned to it — confirm against the C header.
    function torch_optim_AdamW_c(parameters_c, n_params_c, learning_rate_c, &
                                 beta_1_c, beta_2_c, eps_c, weight_decay_c, &
                                 amsgrad_c) &
        result(optim_c) bind(c, name = 'torch_optim_AdamW')
      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
      implicit none
      type(c_ptr), value, intent(in) :: parameters_c
      integer(c_int), value, intent(in) :: n_params_c
      real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                           eps_c, weight_decay_c
      logical(c_bool), value, intent(in) :: amsgrad_c
      type(c_ptr) :: optim_c
    end function torch_optim_AdamW_c
  end interface

  n_params = size(parameters)

  ! Resolve each optional hyperparameter to its PyTorch default when absent.
  learning_rate_value = 0.001_real64
  if (present(learning_rate)) learning_rate_value = learning_rate

  beta_1_value = 0.9_real64
  if (present(beta_1)) beta_1_value = beta_1

  beta_2_value = 0.999_real64
  if (present(beta_2)) beta_2_value = beta_2

  eps_value = 1.0e-8_real64
  if (present(eps)) eps_value = eps

  weight_decay_value = 0.01_real64  ! Different default for AdamW to Adam!
  if (present(weight_decay)) weight_decay_value = weight_decay

  amsgrad_value = .false.
  if (present(amsgrad)) amsgrad_value = amsgrad

  ! Assign array of pointers to the parameters
  do i = 1, n_params
    parameter_ptrs(i) = parameters(i)%p
  end do

  optim%p = torch_optim_AdamW_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                beta_1_value, beta_2_value, eps_value, &
                                weight_decay_value, logical(amsgrad_value, c_bool))
end subroutine torch_optim_AdamW