ftorch_optim.f90 Source File


This file depends on

ftorch_tensor.f90, ftorch_types.f90

Files dependent on this one

ftorch.f90

Source Code

!| Optimisers module for FTorch.
!
!  * License
!    FTorch is released under an MIT license.
!    See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
!    file for details.

module ftorch_optim

  use, intrinsic :: iso_c_binding, only: c_associated, c_null_ptr, c_ptr

  use ftorch_types, only: ftorch_int
  use ftorch_tensor, only: torch_tensor

  implicit none

  public

  ! ============================================================================
  ! --- Derived types
  ! ============================================================================

  !> Type for holding a torch optimizer.
  type torch_optim
    type(c_ptr) :: p = c_null_ptr  !! pointer to the optimizer in memory
  contains
    procedure :: step => torch_optim_step
    procedure :: zero_grad => torch_optim_zero_grad
    final :: torch_optim_delete
  end type torch_optim
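  ! Note: torch_optim_delete is bound as a finalizer, so the underlying
  ! optimizer is released automatically when a torch_optim object is
  ! deallocated or goes out of scope.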

contains

  ! ============================================================================
  ! --- Procedures for using Optimizers
  ! ============================================================================

  !> Zero the gradients of the tensors associated with a Torch optimizer
  subroutine torch_optim_zero_grad(optim)
    class(torch_optim), intent(in) :: optim  !! Optimizer to zero gradients for

    interface
      subroutine torch_optim_zero_grad_c(optim_c) &
          bind(c, name = 'torch_optim_zero_grad')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_zero_grad_c
    end interface

    call torch_optim_zero_grad_c(optim%p)
  end subroutine torch_optim_zero_grad

  !> Step a Torch optimizer
  subroutine torch_optim_step(optim)
    class(torch_optim), intent(in) :: optim  !! Optimizer to step

    interface
      subroutine torch_optim_step_c(optim_c) &
          bind(c, name = 'torch_optim_step')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_step_c
    end interface

    call torch_optim_step_c(optim%p)
  end subroutine torch_optim_step

  !> Deallocate a Torch optimizer
  subroutine torch_optim_delete(optim)
    type(torch_optim), intent(inout) :: optim  !! Optimizer to deallocate

    interface
      subroutine torch_optim_delete_c(optim_c) &
          bind(c, name = 'torch_optim_delete')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_delete_c
    end interface

    ! Call the destructor, if it hasn't already been called
    if (c_associated(optim%p)) then
      call torch_optim_delete_c(optim%p)
      optim%p = c_null_ptr
    end if
  end subroutine torch_optim_delete

  ! ============================================================================
  ! --- Procedures for creating specific Optimizers
  ! ============================================================================

  !> Create an SGD optimizer
  subroutine torch_optim_SGD(optim, parameters, learning_rate, momentum, weight_decay, &
                             dampening, nesterov)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64
    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: momentum  !! momentum for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: dampening  !! dampening for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.0)
    logical, optional, intent(in) :: nesterov  !! enable Nesterov momentum. Only applicable when momentum is non-zero. (default: .false.)
    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: momentum_value  !! Resolved momentum value to be passed to the C interface
    real(kind=real64) :: dampening_value  !! Resolved dampening value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: nesterov_value  !! Resolved nesterov value to be passed to the C interface

    integer(ftorch_int) :: i
    integer(c_int)      :: n_params
    type(c_ptr), dimension(size(parameters)), target  :: parameter_ptrs

    interface
      function torch_optim_SGD_c(parameters_c, n_params_c, learning_rate_c, momentum_c, &
                                 dampening_c, weight_decay_c, nesterov_c) &
          result(optim_c) bind(c, name = 'torch_optim_SGD')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int, c_double
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, momentum_c, dampening_c, &
                                             weight_decay_c
        logical(c_bool), value, intent(in) :: nesterov_c
        type(c_ptr) :: optim_c
      end function torch_optim_SGD_c
    end interface

    n_params = size(parameters)

    if (.not. present(learning_rate)) then
      learning_rate_value = 0.001_real64
    else
      learning_rate_value = learning_rate
    end if

    if (.not. present(momentum)) then
      momentum_value = 0.0_real64
    else
      momentum_value = momentum
    end if

    if (.not. present(dampening)) then
      dampening_value = 0.0_real64
    else
      dampening_value = dampening
    end if

    if (.not. present(weight_decay)) then
      weight_decay_value = 0.0_real64
    else
      weight_decay_value = weight_decay
    end if

    if (.not. present(nesterov)) then
      nesterov_value = .false.
    else
      nesterov_value = nesterov
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_SGD_c(c_loc(parameter_ptrs), n_params, &
                                learning_rate_value, momentum_value, dampening_value, &
                                weight_decay_value, logical(nesterov_value, c_bool))
  end subroutine torch_optim_SGD

  !> Create an Adam optimizer
  subroutine torch_optim_Adam(optim, parameters, learning_rate, beta_1, beta_2, &
                              eps, weight_decay, amsgrad)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64
    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: beta_1  !! beta 1 for the optimization algorithm (default: 0.9)
    real(kind=real64), optional, intent(in) :: beta_2  !! beta 2 for the optimization algorithm (default: 0.999)
    real(kind=real64), optional, intent(in) :: eps  !! eps for the optimization algorithm (default: 1.0e-8)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.0)
    logical, optional, intent(in) :: amsgrad  !! enable AMSGrad variant (default: .false.)
    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: beta_1_value  !! Resolved beta_1 value to be passed to the C interface
    real(kind=real64) :: beta_2_value  !! Resolved beta_2 value to be passed to the C interface
    real(kind=real64) :: eps_value  !! Resolved eps value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: amsgrad_value  !! Resolved amsgrad value to be passed to the C interface

    integer(ftorch_int) :: i
    integer(c_int)      :: n_params
    type(c_ptr), dimension(size(parameters)), target  :: parameter_ptrs

    interface
      function torch_optim_Adam_c(parameters_c, n_params_c, learning_rate_c, &
                                  beta_1_c, beta_2_c, eps_c, weight_decay_c, amsgrad_c) &
          result(optim_c) bind(c, name = 'torch_optim_Adam')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                             eps_c, weight_decay_c
        logical(c_bool), value, intent(in) :: amsgrad_c
        type(c_ptr) :: optim_c
      end function torch_optim_Adam_c
    end interface

    n_params = size(parameters)

    if (.not. present(learning_rate)) then
      learning_rate_value = 0.001_real64
    else
      learning_rate_value = learning_rate
    end if

    if (.not. present(beta_1)) then
      beta_1_value = 0.9_real64
    else
      beta_1_value = beta_1
    end if

    if (.not. present(beta_2)) then
      beta_2_value = 0.999_real64
    else
      beta_2_value = beta_2
    end if

    if (.not. present(eps)) then
      eps_value = 1.0e-8_real64
    else
      eps_value = eps
    end if

    if (.not. present(weight_decay)) then
      weight_decay_value = 0.0_real64
    else
      weight_decay_value = weight_decay
    end if

    if (.not. present(amsgrad)) then
      amsgrad_value = .false.
    else
      amsgrad_value = amsgrad
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_Adam_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                 beta_1_value, beta_2_value, eps_value, &
                                 weight_decay_value, logical(amsgrad_value, c_bool))
  end subroutine torch_optim_Adam

  !> Create an AdamW optimizer
  subroutine torch_optim_AdamW(optim, parameters, learning_rate, beta_1, beta_2, &
                               eps, weight_decay, amsgrad)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64
    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: beta_1  !! beta 1 for the optimization algorithm (default: 0.9)
    real(kind=real64), optional, intent(in) :: beta_2  !! beta 2 for the optimization algorithm (default: 0.999)
    real(kind=real64), optional, intent(in) :: eps  !! eps for the optimization algorithm (default: 1.0e-8)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.01)
    logical, optional, intent(in) :: amsgrad  !! enable AMSGrad variant (default: .false.)
    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: beta_1_value  !! Resolved beta_1 value to be passed to the C interface
    real(kind=real64) :: beta_2_value  !! Resolved beta_2 value to be passed to the C interface
    real(kind=real64) :: eps_value  !! Resolved eps value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: amsgrad_value  !! Resolved amsgrad value to be passed to the C interface

    integer(ftorch_int) :: i
    integer(c_int)      :: n_params
    type(c_ptr), dimension(size(parameters)), target  :: parameter_ptrs

    interface
      function torch_optim_AdamW_c(parameters_c, n_params_c, learning_rate_c, &
                                   beta_1_c, beta_2_c, eps_c, weight_decay_c, amsgrad_c) &
          result(optim_c) bind(c, name = 'torch_optim_AdamW')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                             eps_c, weight_decay_c
        logical(c_bool), value, intent(in) :: amsgrad_c
        type(c_ptr) :: optim_c
      end function torch_optim_AdamW_c
    end interface

    n_params = size(parameters)

    if (.not. present(learning_rate)) then
      learning_rate_value = 0.001_real64
    else
      learning_rate_value = learning_rate
    end if

    if (.not. present(beta_1)) then
      beta_1_value = 0.9_real64
    else
      beta_1_value = beta_1
    end if

    if (.not. present(beta_2)) then
      beta_2_value = 0.999_real64
    else
      beta_2_value = beta_2
    end if

    if (.not. present(eps)) then
      eps_value = 1.0e-8_real64
    else
      eps_value = eps
    end if

    if (.not. present(weight_decay)) then
      weight_decay_value = 0.01_real64  ! Note: the default weight_decay for AdamW differs from Adam
    else
      weight_decay_value = weight_decay
    end if

    if (.not. present(amsgrad)) then
      amsgrad_value = .false.
    else
      amsgrad_value = amsgrad
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_AdamW_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                  beta_1_value, beta_2_value, eps_value, &
                                  weight_decay_value, logical(amsgrad_value, c_bool))
  end subroutine torch_optim_AdamW

end module ftorch_optim
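
For reference, a minimal training-loop sketch using this module might look as follows. This is an illustrative assumption rather than an example taken from the library: it assumes the parameter tensors in params have already been created elsewhere with gradient tracking enabled, and that the forward pass and back-propagation are handled by routines from ftorch_tensor, indicated here only as comments.

program train_sketch
  use ftorch_optim, only: torch_optim, torch_optim_SGD
  use ftorch_tensor, only: torch_tensor
  use, intrinsic :: iso_fortran_env, only: real64
  implicit none

  type(torch_tensor), dimension(2) :: params  ! parameter tensors, assumed to be
                                              ! set up elsewhere with gradients enabled
  type(torch_optim) :: optim
  integer :: epoch

  ! Build an SGD optimizer over the parameter tensors; options not supplied
  ! take the defaults documented above (e.g. weight_decay = 0.0).
  call torch_optim_SGD(optim, params, learning_rate=0.01_real64, momentum=0.9_real64)

  do epoch = 1, 100
    call optim%zero_grad()  ! clear gradients accumulated in the previous iteration
    ! ... forward pass: compute a loss tensor and back-propagate it using
    !     routines from ftorch_tensor (not shown) ...
    call optim%step()       ! apply a parameter update from the current gradients
  end do

  ! The finalizer releases the underlying optimizer when optim goes out of scope.
end program train_sketch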