!| Optimisers module for FTorch.
!
!  * License
!    FTorch is released under an MIT license.
!    See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
!    file for details.
module ftorch_optim

  use, intrinsic :: iso_c_binding, only: c_associated, c_null_ptr, c_ptr
  use ftorch_types, only: ftorch_int
  use ftorch_tensor, only: torch_tensor

  implicit none

  public

  ! ============================================================================
  ! --- Derived types
  ! ============================================================================

  !> Type for holding a torch optimizer.
  type torch_optim
    type(c_ptr) :: p = c_null_ptr  !! pointer to the optimizer in memory
  contains
    procedure :: step => torch_optim_step
    procedure :: zero_grad => torch_optim_zero_grad
    final :: torch_optim_delete
  end type torch_optim

contains

  ! ============================================================================
  ! --- Procedures for using Optimizers
  ! ============================================================================

  !> Zero gradients on all parameter tensors associated with a Torch optimizer.
  subroutine torch_optim_zero_grad(optim)
    class(torch_optim), intent(in) :: optim  !! Optimizer to zero gradients for

    interface
      subroutine torch_optim_zero_grad_c(optim_c) &
          bind(c, name = 'torch_optim_zero_grad')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_zero_grad_c
    end interface

    call torch_optim_zero_grad_c(optim%p)
  end subroutine torch_optim_zero_grad

  !> Advance a Torch optimizer by one step, updating the parameters it holds.
  subroutine torch_optim_step(optim)
    class(torch_optim), intent(in) :: optim  !! Optimizer to step

    interface
      subroutine torch_optim_step_c(optim_c) &
          bind(c, name = 'torch_optim_step')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_step_c
    end interface

    call torch_optim_step_c(optim%p)
  end subroutine torch_optim_step

  !> Deallocate a Torch optimizer (finalizer for torch_optim).
  subroutine torch_optim_delete(optim)
    type(torch_optim), intent(inout) :: optim  !! Optimizer to deallocate

    interface
      subroutine torch_optim_delete_c(optim_c) &
          bind(c, name = 'torch_optim_delete')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: optim_c
      end subroutine torch_optim_delete_c
    end interface

    ! Call the destructor, if it hasn't already been called
    if (c_associated(optim%p)) then
      call torch_optim_delete_c(optim%p)
      optim%p = c_null_ptr
    end if
  end subroutine torch_optim_delete

  ! ============================================================================
  ! --- Procedures for creating specific Optimizers
  ! ============================================================================

  !> Create an SGD optimizer over the given parameter tensors.
  subroutine torch_optim_SGD(optim, parameters, learning_rate, momentum, weight_decay, &
                             dampening, nesterov)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64

    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: momentum  !! momentum for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: dampening  !! dampening for the optimization algorithm (default: 0.0)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.0)
    logical, optional, intent(in) :: nesterov  !! enable Nesterov momentum. Only applicable when momentum is non-zero. (default: .false.)

    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: momentum_value      !! Resolved momentum value to be passed to the C interface
    real(kind=real64) :: dampening_value     !! Resolved dampening value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: nesterov_value                !! Resolved nesterov value to be passed to the C interface
    integer(ftorch_int) :: i
    integer(c_int) :: n_params
    type(c_ptr), dimension(size(parameters)), target :: parameter_ptrs

    interface
      function torch_optim_SGD_c(parameters_c, n_params_c, learning_rate_c, momentum_c, &
                                 dampening_c, weight_decay_c, nesterov_c) &
          result(optim_c) bind(c, name = 'torch_optim_SGD')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int, c_double
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, momentum_c, dampening_c, &
                                             weight_decay_c
        logical(c_bool), value, intent(in) :: nesterov_c
        type(c_ptr) :: optim_c
      end function torch_optim_SGD_c
    end interface

    n_params = size(parameters)

    ! Resolve each optional argument against its documented default
    if (present(learning_rate)) then
      learning_rate_value = learning_rate
    else
      learning_rate_value = 0.001_real64
    end if

    if (present(momentum)) then
      momentum_value = momentum
    else
      momentum_value = 0.0_real64
    end if

    if (present(dampening)) then
      dampening_value = dampening
    else
      dampening_value = 0.0_real64
    end if

    if (present(weight_decay)) then
      weight_decay_value = weight_decay
    else
      weight_decay_value = 0.0_real64
    end if

    if (present(nesterov)) then
      nesterov_value = nesterov
    else
      nesterov_value = .false.
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_SGD_c(c_loc(parameter_ptrs), n_params, &
                                learning_rate_value, momentum_value, dampening_value, &
                                weight_decay_value, logical(nesterov_value, c_bool))
  end subroutine torch_optim_SGD

  !> Create an Adam optimizer over the given parameter tensors.
  subroutine torch_optim_Adam(optim, parameters, learning_rate, beta_1, beta_2, &
                              eps, weight_decay, amsgrad)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64

    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: beta_1  !! beta 1 for the optimization algorithm (default: 0.9)
    real(kind=real64), optional, intent(in) :: beta_2  !! beta 2 for the optimization algorithm (default: 0.999)
    real(kind=real64), optional, intent(in) :: eps  !! eps for the optimization algorithm (default: 1.0e-8)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.0)
    logical, optional, intent(in) :: amsgrad  !! enable AMSGrad variant (default: .false.)

    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: beta_1_value        !! Resolved beta_1 value to be passed to the C interface
    real(kind=real64) :: beta_2_value        !! Resolved beta_2 value to be passed to the C interface
    real(kind=real64) :: eps_value           !! Resolved eps value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: amsgrad_value                 !! Resolved amsgrad value to be passed to the C interface
    integer(ftorch_int) :: i
    integer(c_int) :: n_params
    type(c_ptr), dimension(size(parameters)), target :: parameter_ptrs

    interface
      function torch_optim_Adam_c(parameters_c, n_params_c, learning_rate_c, &
                                  beta_1_c, beta_2_c, eps_c, weight_decay_c, amsgrad_c) &
          result(optim_c) bind(c, name = 'torch_optim_Adam')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                             eps_c, weight_decay_c
        logical(c_bool), value, intent(in) :: amsgrad_c
        type(c_ptr) :: optim_c
      end function torch_optim_Adam_c
    end interface

    n_params = size(parameters)

    ! Resolve each optional argument against its documented default
    if (present(learning_rate)) then
      learning_rate_value = learning_rate
    else
      learning_rate_value = 0.001_real64
    end if

    if (present(beta_1)) then
      beta_1_value = beta_1
    else
      beta_1_value = 0.9_real64
    end if

    if (present(beta_2)) then
      beta_2_value = beta_2
    else
      beta_2_value = 0.999_real64
    end if

    if (present(eps)) then
      eps_value = eps
    else
      eps_value = 1.0e-8_real64
    end if

    if (present(weight_decay)) then
      weight_decay_value = weight_decay
    else
      weight_decay_value = 0.0_real64
    end if

    if (present(amsgrad)) then
      amsgrad_value = amsgrad
    else
      amsgrad_value = .false.
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_Adam_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                 beta_1_value, beta_2_value, eps_value, &
                                 weight_decay_value, logical(amsgrad_value, c_bool))
  end subroutine torch_optim_Adam

  !> Create an AdamW optimizer over the given parameter tensors.
  subroutine torch_optim_AdamW(optim, parameters, learning_rate, beta_1, beta_2, &
                               eps, weight_decay, amsgrad)
    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_loc, c_bool
    use, intrinsic :: iso_fortran_env, only : real64

    type(torch_optim), intent(out) :: optim  !! Optimizer we are creating
    type(torch_tensor), intent(in), dimension(:) :: parameters  !! Array of parameter tensors
    real(kind=real64), optional, intent(in) :: learning_rate  !! learning rate for the optimization algorithm (default: 0.001)
    real(kind=real64), optional, intent(in) :: beta_1  !! beta 1 for the optimization algorithm (default: 0.9)
    real(kind=real64), optional, intent(in) :: beta_2  !! beta 2 for the optimization algorithm (default: 0.999)
    real(kind=real64), optional, intent(in) :: eps  !! eps for the optimization algorithm (default: 1.0e-8)
    real(kind=real64), optional, intent(in) :: weight_decay  !! weight_decay for the optimization algorithm (default: 0.01)
    logical, optional, intent(in) :: amsgrad  !! enable AMSGrad variant (default: .false.)

    real(kind=real64) :: learning_rate_value  !! Resolved learning_rate value to be passed to the C interface
    real(kind=real64) :: beta_1_value        !! Resolved beta_1 value to be passed to the C interface
    real(kind=real64) :: beta_2_value        !! Resolved beta_2 value to be passed to the C interface
    real(kind=real64) :: eps_value           !! Resolved eps value to be passed to the C interface
    real(kind=real64) :: weight_decay_value  !! Resolved weight_decay value to be passed to the C interface
    logical :: amsgrad_value                 !! Resolved amsgrad value to be passed to the C interface
    integer(ftorch_int) :: i
    integer(c_int) :: n_params
    type(c_ptr), dimension(size(parameters)), target :: parameter_ptrs

    ! NOTE: the dummy-argument order here previously read
    ! "..., beta_2_c, weight_decay_c, eps_c, amsgrad_c", which did not match the
    ! positional call below (eps passed before weight_decay) and silently swapped
    ! the two values at the C boundary. It now matches torch_optim_Adam_c's order.
    interface
      function torch_optim_AdamW_c(parameters_c, n_params_c, learning_rate_c, &
                                   beta_1_c, beta_2_c, eps_c, weight_decay_c, amsgrad_c) &
          result(optim_c) bind(c, name = 'torch_optim_AdamW')
        use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_double, c_bool
        implicit none
        type(c_ptr), value, intent(in) :: parameters_c
        integer(c_int), value, intent(in) :: n_params_c
        real(c_double), value, intent(in) :: learning_rate_c, beta_1_c, beta_2_c, &
                                             eps_c, weight_decay_c
        logical(c_bool), value, intent(in) :: amsgrad_c
        type(c_ptr) :: optim_c
      end function torch_optim_AdamW_c
    end interface

    n_params = size(parameters)

    ! Resolve each optional argument against its documented default
    if (present(learning_rate)) then
      learning_rate_value = learning_rate
    else
      learning_rate_value = 0.001_real64
    end if

    if (present(beta_1)) then
      beta_1_value = beta_1
    else
      beta_1_value = 0.9_real64
    end if

    if (present(beta_2)) then
      beta_2_value = beta_2
    else
      beta_2_value = 0.999_real64
    end if

    if (present(eps)) then
      eps_value = eps
    else
      eps_value = 1.0e-8_real64
    end if

    if (present(weight_decay)) then
      weight_decay_value = 0.01_real64 ! Different default for AdamW to Adam!
      weight_decay_value = merge(weight_decay_value, weight_decay_value, .true.)
    else
      weight_decay_value = weight_decay_value
    end if
    ! Resolve weight_decay explicitly (default 0.01 for AdamW, unlike Adam's 0.0)
    if (present(weight_decay)) then
      weight_decay_value = weight_decay
    else
      weight_decay_value = 0.01_real64 ! Different default for AdamW to Adam!
    end if

    if (present(amsgrad)) then
      amsgrad_value = amsgrad
    else
      amsgrad_value = .false.
    end if

    ! Assign array of pointers to the parameters
    do i = 1, n_params
      parameter_ptrs(i) = parameters(i)%p
    end do

    optim%p = torch_optim_AdamW_c(c_loc(parameter_ptrs), n_params, learning_rate_value, &
                                  beta_1_value, beta_2_value, eps_value, &
                                  weight_decay_value, logical(amsgrad_value, c_bool))
  end subroutine torch_optim_AdamW

end module ftorch_optim