torch_optim_SGD Subroutine

public subroutine torch_optim_SGD(optim, parameters, learning_rate, momentum, weight_decay, dampening, nesterov)

Uses

  • iso_c_binding (intrinsic)
  • iso_fortran_env (intrinsic)

Create an SGD optimizer

Arguments

Type | Intent | Optional | Attributes | Name
type(torch_optim), intent(out) :: optim

Optimizer we are creating

type(torch_tensor), intent(in), dimension(:) :: parameters

Array of parameter tensors

real(kind=real64), intent(in), optional :: learning_rate

learning rate for the optimization algorithm (default: 0.001)

real(kind=real64), intent(in), optional :: momentum

momentum for the optimization algorithm (default: 0.0)

real(kind=real64), intent(in), optional :: weight_decay

weight_decay for the optimization algorithm (default: 0.0)

real(kind=real64), intent(in), optional :: dampening

dampening for the optimization algorithm (default: 0.0)

logical, intent(in), optional :: nesterov

enable Nesterov momentum. Only applicable when momentum is non-zero. (default: .false.)


Source Code

  subroutine torch_optim_SGD(optim, parameters, learning_rate, momentum, weight_decay, &
                             dampening, nesterov)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64

    !> Optimizer we are creating
    type(torch_optim), intent(out) :: optim
    !> Array of parameter tensors
    type(torch_tensor), intent(in), dimension(:) :: parameters
    !> learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: learning_rate
    !> momentum for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: momentum
    !> dampening for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: dampening
    !> weight_decay for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: weight_decay
    !> enable Nesterov momentum. Only applicable when momentum is non-zero. (default: .false.)
    logical, optional, intent(in) :: nesterov

    ! Resolved option values that are forwarded to the C interface
    real(kind=real64) :: lr_resolved
    real(kind=real64) :: momentum_resolved
    real(kind=real64) :: dampening_resolved
    real(kind=real64) :: wd_resolved
    logical :: nesterov_resolved

    integer(c_int) :: num_params
    type(c_ptr), dimension(size(parameters)), target :: param_handles

    interface
      function torch_optim_SGD_c(parameters_c, n_params_c, learning_rate_c, momentum_c, &
                                 dampening_c, weight_decay_c, nesterov_c) &
          result(optim_c) bind(c, name = 'torch_optim_SGD')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int, c_double
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, momentum_c, dampening_c, &
                                             weight_decay_c
        logical(c_bool), value, intent(in) :: nesterov_c
        type(c_ptr) :: optim_c
      end function torch_optim_SGD_c
    end interface

    ! Start from the documented defaults, then override with any supplied values.
    ! An absent optional dummy must never be referenced, hence the present() guards.
    lr_resolved = 0.001_real64
    momentum_resolved = 0.0_real64
    dampening_resolved = 0.0_real64
    wd_resolved = 0.0_real64
    nesterov_resolved = .false.
    if (present(learning_rate)) lr_resolved = learning_rate
    if (present(momentum)) momentum_resolved = momentum
    if (present(dampening)) dampening_resolved = dampening
    if (present(weight_decay)) wd_resolved = weight_decay
    if (present(nesterov)) nesterov_resolved = nesterov

    num_params = size(parameters)

    ! Gather the underlying C pointers of all parameter tensors into one
    ! contiguous array so a single pointer can be handed to the C layer
    param_handles = parameters%p

    optim%p = torch_optim_SGD_c(c_loc(param_handles), num_params, &
                                lr_resolved, momentum_resolved, dampening_resolved, &
                                wd_resolved, logical(nesterov_resolved, c_bool))
  end subroutine torch_optim_SGD