torch_optim_AdamW Subroutine

public subroutine torch_optim_AdamW(optim, parameters, learning_rate, beta_1, beta_2, eps, weight_decay, amsgrad)

Uses

  • iso_c_binding
  • iso_fortran_env

Create an AdamW optimizer (Adam with decoupled weight decay) for the given array of parameter tensors.
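
For reference, the hyperparameters correspond to the decoupled weight-decay update of Loshchilov and Hutter (2019); a sketch of one step for a parameter theta with gradient g_t, learning rate gamma (learning_rate) and decay lambda (weight_decay):

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= m_t/(1-\beta_1^t), \quad \hat{v}_t = v_t/(1-\beta_2^t) \\
\theta_t &= \theta_{t-1} - \gamma \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda\, \theta_{t-1} \right)
\end{aligned}
$$

With amsgrad = .true., the AMSGrad variant replaces \(\hat{v}_t\) by its running maximum across steps.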

Arguments

type(torch_optim), intent(out) :: optim

Optimizer we are creating

type(torch_tensor), intent(in), dimension(:) :: parameters

Array of parameter tensors

real(kind=real64), intent(in), optional :: learning_rate

Learning rate for the optimization algorithm (default: 0.001)

real(kind=real64), intent(in), optional :: beta_1

Decay rate for the first-moment estimate (default: 0.9)

real(kind=real64), intent(in), optional :: beta_2

Decay rate for the second-moment estimate (default: 0.999)

real(kind=real64), intent(in), optional :: eps

Term added to the denominator for numerical stability (default: 1.0e-8)

real(kind=real64), intent(in), optional :: weight_decay

Weight-decay coefficient (default: 0.01; note this differs from Adam's default of 0.0)

logical, intent(in), optional :: amsgrad

Enable the AMSGrad variant (default: .false.)
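
Example

A minimal usage sketch (an illustration, not taken from the library documentation), assuming the FTorch bindings are provided by a module named ftorch and that the parameter tensors have already been created (for example with torch_tensor_from_array):

  program adamw_example
    use, intrinsic :: iso_fortran_env, only : real64
    use ftorch  ! assumed module name exporting torch_tensor, torch_optim, torch_optim_AdamW

    implicit none

    type(torch_tensor), dimension(2) :: params  ! parameter tensors, created beforehand
    type(torch_optim) :: optim

    ! ... create params(1) and params(2) here, e.g. via torch_tensor_from_array ...

    ! All hyperparameters are optional: call torch_optim_AdamW(optim, params)
    ! would use the documented defaults. Here a few are overridden by keyword:
    call torch_optim_AdamW(optim, params, learning_rate=3.0e-4_real64, &
                           weight_decay=0.1_real64, amsgrad=.true.)
  end program adamw_example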


Source Code

  subroutine torch_optim_AdamW(optim, parameters, learning_rate, beta_1, beta_2, &
                               eps, weight_decay, amsgrad)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64
    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! Learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: beta_1  !! Decay rate for the first-moment estimate (default: 0.9)
    real(kind=real64), optional, intent(in) :: beta_2  !! Decay rate for the second-moment estimate (default: 0.999)
    real(kind=real64), optional, intent(in) :: eps  !! Term added to the denominator for numerical stability (default: 1.0e-8)
    real(kind=real64), optional, intent(in) :: weight_decay  !! Weight-decay coefficient (default: 0.01)
    logical, optional, intent(in) :: amsgrad  !! Enable the AMSGrad variant (default: .false.)
    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: beta_1_value  !! Resolved beta_1 value to be passed to the C interface
    real(kind=real64) :: beta_2_value  !! Resolved beta_2 value to be passed to the C interface
    real(kind=real64) :: eps_value  !! Resolved eps value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: amsgrad_value  !! Resolved amsgrad value to be passed to the C interface

    integer(ftorch_int) :: i
    integer(c_int)      :: n_params
    type(c_ptr), dimension(size(parameters)), target  :: parameter_ptrs

    interface
      function torch_optim_AdamW_c(parameters_c, n_params_c, learning_rate_c, &
                                   beta_1_c, beta_2_c, eps_c, weight_decay_c, amsgrad_c) &
          result(optim_c) bind(c, name = 'torch_optim_AdamW')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                             eps_c, weight_decay_c
        logical(c_bool), value, intent(in) :: amsgrad_c
        type(c_ptr) :: optim_c
      end function torch_optim_AdamW_c
    end interface

    n_params = size(parameters)

    if (.not. present(learning_rate)) then
      learning_rate_value = 0.001_real64
    else
      learning_rate_value = learning_rate
    end if

    if (.not. present(beta_1)) then
      beta_1_value = 0.9_real64
    else
      beta_1_value = beta_1
    end if

    if (.not. present(beta_2)) then
      beta_2_value = 0.999_real64
    else
      beta_2_value = beta_2
    end if

    if (.not. present(eps)) then
      eps_value = 1.0e-8_real64
    else
      eps_value = eps
    end if

    if (.not. present(weight_decay)) then
      weight_decay_value = 0.01_real64  ! NB: AdamW's default weight decay differs from Adam's (0.0)
    else
      weight_decay_value = weight_decay
    end if

    if (.not. present(amsgrad)) then
      amsgrad_value = .false.
    else
      amsgrad_value = amsgrad
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_AdamW_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                  beta_1_value, beta_2_value, eps_value, &
                                  weight_decay_value, logical(amsgrad_value, c_bool))
  end subroutine torch_optim_AdamW
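
Note that the optional hyperparameters are resolved to concrete local values on the Fortran side before the single bind(c) call, since absent optional arguments cannot be passed by value to a C function; similarly, the parameter tensors are marshalled into a contiguous array of c_ptr handles whose address is passed via c_loc.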