! Subroutine to fit a hyperplane to a set of data points having error
! bars in each dimension.  Based on Deming's formulation of the
! least-squares problem.  Makes the intercept implicit by shifting the data
! to the weighted-average point.  (This improves the conditioning of the
! equations and reduces their dimension by 1.)  Uses fixed-point iteration
! to find an approximate solution, then Newton-Raphson iteration to
! converge.
!
! Author: Robert K. Moniot, Fordham University
! Date:   December 2005
!
!  Equation of hyperplane:
!     a(1)*Y(1) + a(2)*Y(2) + ... + a(m-1)*Y(m-1) + a(m) = Y(m)
!  where Y is a point on the hyperplane, and m is the number of coordinates
!  in the space.
!
! Copyright (c) 2005 by Robert K. Moniot.
!
! Permission is hereby granted, free of charge, to any person
! obtaining a copy of this software and associated documentation
! files (the "Software"), to deal in the Software without
! restriction, including without limitation the rights to use,
! copy, modify, merge, publish, distribute, sublicense, and/or
! sell copies of the Software, and to permit persons to whom the
! Software is furnished to do so, subject to the following
! conditions:

! The above copyright notice and this permission notice shall be
! included in all copies or substantial portions of the
! Software.

! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
! KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
! WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
! PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
! COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
! OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
! SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
!
! Acknowledgement: the above permission notice is known as the X
! Consortium license.

module hyperplane_fit

  implicit none
  private
  public :: hyperplane

contains

  ! Fit the hyperplane
  !    a(1)*Y(1) + ... + a(m-1)*Y(m-1) + a(m) = Y(m),    m = mdims
  ! to npts points Y(i,1:mdims) having error bars sigma_Y(i,1:mdims)
  ! in every coordinate (Deming's least-squares formulation).
  !
  ! Only Y(1:npts,1:mdims), sigma_Y(1:npts,1:mdims), a(1:mdims) and
  ! V(1:mdims,1:mdims) are referenced, so callers may pass larger arrays.
  !
  ! Arguments:
  !   npts, mdims        number of data points / coordinates per point
  !                      (mdims must be at least 2)
  !   Y, sigma_Y         data points and their error bars
  !   a                  out: hyperplane coefficients; a(mdims) is the
  !                      intercept
  !   chisq              out: sum(W*fval**2) from the last iteration
  !                      (printed as "Variance")
  !   status             out: 0 = Newton step met conv_crit
  !                           1 = stopped without meeting conv_crit
  !                           2 = failure (infinite weight, large last
  !                               step at the accuracy limit, or LAPACK
  !                               failure while computing V)
  !   V                  optional out: covariance matrix of a; assigned
  !                      only if status < 2 after the Newton stage and
  !                      its LAPACK solves succeed
  !   conv_crit          relative tolerance on the Newton step
  !                      (default 10*epsilon(a))
  !   max_iters          Newton iteration limit (default precision(a))
  !   fixedpt_conv_crit  relative tolerance for the fixed-point stage
  !                      (default 1e-4)
  !   fixedpt_iters      fixed-point iteration limit (default 2)
  !   verbosity          amount of diagnostic output (default 0)
  !   exact_a            optional true solution; used only in diagnostics
  !
  ! Inconsistent argument sizes are fatal: a message is printed and the
  ! program stops.
  subroutine hyperplane( npts, mdims, Y, sigma_Y, a, chisq, status, &
       V, conv_crit, max_iters, fixedpt_conv_crit, fixedpt_iters, &
       verbosity, exact_a )
    use precision_module, only: SP, DP

    implicit none
    integer, intent(in) :: npts                    ! number of data points
    integer, intent(in) :: mdims                   ! number of dimensions
    real(kind=DP), intent(in), dimension(:,:) :: Y ! the given data points
    real(kind=DP), intent(in), dimension(:,:) :: sigma_Y ! given error bars
    real(kind=DP), intent(out), dimension(:) :: a  ! coeffs of hyperplane eqn
    real(kind=DP), intent(out) :: chisq            ! variance
    integer, intent(out) :: status                 ! termination status
    real(kind=DP), optional, intent(out), dimension(:,:) :: V ! covariance of a
    real(kind=DP), optional, intent(in) :: conv_crit     ! convergence criterion
    integer, optional, intent(in) :: max_iters   ! limit to number of iterations
    real(kind=DP), optional, intent(in) :: fixedpt_conv_crit     ! conv crit for fixed-pt
    integer, optional, intent(in) :: fixedpt_iters   ! limit to fixed-pt iterations
    integer, optional, intent(in) :: verbosity     ! amount of debugging output
    real(kind=DP), optional, intent(in), dimension(:) :: exact_a ! true soln for use in conv tests

    real(kind=DP) :: norm_delta_a, prev_norm_delta_a ! for detecting when accuracy limit reached
    real(kind=DP), allocatable, dimension(:) :: prev_a
    real(kind=DP), allocatable, dimension(:) :: W    ! weights
    real(kind=DP), allocatable, dimension(:) :: Ybar ! mean point
    real(kind=DP), allocatable, dimension(:,:) :: u  ! sigma^2
    real(kind=DP), allocatable, dimension(:,:) :: yp ! Y - Ybar
    real(kind=DP), allocatable, dimension(:) :: fval ! f(Y) values
    real(kind=DP), allocatable, dimension(:,:) :: z, zp  ! adjustments
    real(kind=DP), allocatable, dimension(:,:) :: Mat  ! matrix to solve for a
    real(kind=DP), allocatable, dimension(:) :: RHS ! const vec in eqn for a
    integer, allocatable, dimension(:) :: ipiv  ! pivot indices
    real(kind=DP), allocatable, dimension(:) :: delta_a, old_delta_a ! for Aitken
    real(kind=DP) :: sumW

    real(kind=DP), external :: ddot ! blas routine

    integer :: iter
    integer :: i, j, k
    integer :: lapack_info             ! return code from lapack routines
    ! Local variables that take on the optionally provided value or a default
    integer :: max_iters_val
    real(kind=DP) :: conv_crit_val
    integer :: verbosity_val
    integer :: fixedpt_iters_val
    real(kind=DP) :: fixedpt_conv_crit_val

    ! For calculating covariance matrix
    real(kind=DP), allocatable, dimension(:) :: zbar  ! ave of z
    real(kind=DP), allocatable, dimension(:,:) :: mu
    ! Contiguous mdims x mdims work copy of V, so that LAPACK/BLAS leading
    ! dimensions are exact whatever the caller's V shape/stride is.
    real(kind=DP), allocatable, dimension(:,:) :: Vw
    ! Separate output vector for dgemv so it does not alias its matrix arg.
    real(kind=DP), allocatable, dimension(:) :: cov_ab

    ! Perform sanity check on the given array sizes
    if( size(Y,1) < npts .or. size(Y,2) < mdims .or. &
         size(sigma_Y,1) < npts .or. size(sigma_Y,2) < mdims .or. &
         size(a) < mdims ) then
       print *, "Error:  subroutine hyperplane: data array must be size (", &
            npts,",",mdims,")"
       print *, "and parameter array must be size (",mdims,")"
       stop
    end if
    ! At least one slope plus the intercept is required (LAPACK needs N >= 1)
    if( mdims < 2 ) then
       print *, "Error:  subroutine hyperplane: need mdims >= 2, got", mdims
       stop
    end if
    ! Nested test: size(V) must not be evaluated when V is absent
    if( present(V) ) then
       if( size(V,1) < mdims .or. size(V,2) < mdims ) then
          print *, "Error:  subroutine hyperplane: covariance array must be size (", &
               mdims,",",mdims,")"
          stop
       end if
    end if

    ! Set default value of iteration limit if not given
    if( present( max_iters ) ) then
       max_iters_val = max_iters
    else
       max_iters_val = precision(a)  ! assumes get > 1 digit per iteration
    end if

    ! Set default value of convergence criterion if not given
    if( present( conv_crit ) ) then
       conv_crit_val = conv_crit
    else
       conv_crit_val = 10*epsilon(a)
    end if

    ! Note that fixed-point iteration will stop whenever fixedpt_iters
    ! is exceeded or fixedpt_conv_crit is reached, whichever occurs first.

    ! Set default value of fixed-point iteration limit if not given
    if( present( fixedpt_iters ) ) then
       fixedpt_iters_val = fixedpt_iters
    else
       fixedpt_iters_val = 2  ! usually 2 iters gets a good guess
    end if

    ! Set default value of fixed-point convergence crit if not given
    if( present( fixedpt_conv_crit ) ) then
       fixedpt_conv_crit_val = fixedpt_conv_crit
    else
       fixedpt_conv_crit_val = 1.0e-4_DP  ! plenty good start for Newton
    end if

    ! Set default value of verbosity level if not given
    if( present( verbosity ) ) then
       verbosity_val = verbosity
    else
       verbosity_val = 0
    end if

    if(verbosity_val >= 1) then
       print *, "Hyperplane fit by newton-raphson iteration"
       print *, "Verbosity level", verbosity_val
       write(unit=*,fmt="(a,i10)")    "Using max iterations  =", max_iters_val
       write(unit=*,fmt="(a,es10.2)") "Using convergence crit=", conv_crit_val
       write(unit=*,fmt="(a,i7)")    "deg of freedom=",(mdims*(npts-1))
    end if

    allocate( W(1:npts), fval(1:npts), Ybar(1:mdims), prev_a(1:mdims) )
    allocate( Mat(1:mdims-1,1:mdims-1), RHS(1:mdims-1), ipiv(1:mdims-1) )
    allocate( u(1:npts,1:mdims), yp(1:npts,1:mdims), z(1:npts,1:mdims))
    allocate(delta_a(1:mdims-1), old_delta_a(1:mdims-1))
    ! Define these before their first use (old_delta_a = delta_a below)
    delta_a = 0.0_DP
    old_delta_a = 0.0_DP

    u(1:npts,1:mdims) = sigma_Y(1:npts,1:mdims)**2

    if( any(u(:,mdims) == 0.0_DP) ) then
       a(1:mdims) = 1.0_DP      ! safe choice when Y is exact
       prev_a(1:mdims) = 0.0_DP ! make it different
    else
       a(1:mdims) = 0.0_DP      ! choice that makes 1st iter the std WLS fit
       prev_a(1:mdims) = 1.0_DP
    end if

!
! First we do a few fixed-point iterations to obtain a good starting
! point for the newton-raphson method.  Unlike newton-raphson, the
! fixed-point method is good at converging from a bad starting guess.
!

    status = 1  ! start with status = failed-to-converge
    do iter=1,fixedpt_iters_val              ! begin fixed-point iteration
       do i=1,npts
          W(i) = sum( a(1:mdims-1)**2*u(i,1:mdims-1) ) + u(i,mdims)
       end do
       if( any( W == 0.0_DP ) ) then
          print *, "Error: subroutine hyperplane: infinite weight in point", &
               minloc(W), " (fixed-point iteration ", iter, ")"
          status = 2            ! failure
          exit
       end if
       W = 1.0_DP/W             ! divide now that we know it is safe
       sumW = sum(W)
       do j=1,mdims
          Ybar(j) = sum(W*Y(1:npts,j))/sumW
          yp(1:npts,j) = Y(1:npts,j) - Ybar(j)
       end do
       ! Compute the residuals
       do i=1,npts
          fval(i) = sum( a(1:mdims-1)*yp(i,1:mdims-1) ) - yp(i,mdims)
       end do
       ! Compute z = adjusted points - Ybar (y' in Williamson's notation).
       ! Only a(1:mdims) is used so the operands conform when size(a) > mdims.
       z = yp - spread( W*fval, 2, mdims)*spread(a(1:mdims),1,npts)*u
       do j=1,mdims-1
          do k=1,mdims-1
             Mat(j,k) = sum(W*z(:,j)*yp(:,k))
          end do
       end do
       do j=1,mdims-1
          RHS(j) = sum(W*z(:,j)*yp(:,mdims))
       end do
       if( verbosity_val >= 4 ) then
          print *, "Ybar="
          write(unit=*,fmt=*) Ybar
          print *, "     W         fval      y'"
          do i=1,npts
             write(unit=*,fmt=*) W(i), fval(i), yp(i,:)
          end do
       end if
       chisq = sum(W*fval**2)
       if( verbosity_val >= 1 ) then
          write(unit=*,fmt="(a,f13.4)") "Variance      =", chisq
       end if
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) "Mat and RHS:"
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Mat(j,:),SP), real(RHS(j),SP)
          end do
       end if

       ! Solve the system Mat*a = RHS (RHS is overwritten by the solution)
       call dgesv( mdims-1, 1, Mat, mdims-1, ipiv, RHS, mdims-1, lapack_info )

       if( lapack_info /= 0 ) then
          print *, "Failed to solve system of equations for coeffs"
          print *, "DGESV INFO=", lapack_info
          exit
       end if

       a(1:mdims-1) = RHS
       a(mdims) = Ybar(mdims) - sum( a(1:mdims-1)*Ybar(1:mdims-1) )
       old_delta_a = delta_a
       delta_a = prev_a(1:mdims-1)-a(1:mdims-1)

       if(verbosity_val >= 1) then
          write(unit=*,fmt=*) "Fixed-point iter=", iter, " a="
          write(unit=*,fmt="(4f20.16)")  a(1:mdims)
       end if
       if(verbosity_val >= 2) then
          if( present(exact_a) ) then
             write(unit=*,fmt=*) "error="
             write(unit=*,fmt="(4e11.3)") a(1:mdims-1)-exact_a(1:mdims-1)
          end if
          write(unit=*,fmt=*) "delta a="
          write(unit=*,fmt="(4e11.3)") delta_a(1:mdims-1)
          if( iter > 1 ) then
             write(unit=*,fmt=*) "delta a/prev delta a="
             write(unit=*,fmt="(4e11.3)") delta_a(1:mdims-1)/old_delta_a(1:mdims-1)
          end if
       end if


       ! test for convergence
       if( all(abs(delta_a) <= fixedpt_conv_crit_val*abs(a(1:mdims-1))) ) then
          status = 0  ! successful convergence
          exit
       end if
       prev_a = a(1:mdims)    ! keep prev_a at its allocated size mdims
    end do                      ! end of fixed-point iteration

    if(verbosity_val >= 1 .and. status /= 0) then
       write(unit=*,fmt=*) &
         "WARNING: fixed-point iteration failed to converge to accuracy of", &
         fixedpt_conv_crit_val, " in ", fixedpt_iters_val, " iterations"
    end if

!
! Now use newton-raphson to converge quickly to high accuracy.
!
    deallocate( z )
    allocate( mu(1:npts,1:mdims-1) )
    allocate( z(1:npts,1:mdims-1), zp(1:npts,1:mdims-1), zbar(1:mdims-1) )

    status = 1  ! start with status = failed-to-converge
    norm_delta_a = 0.0_DP

    do iter=1,max_iters_val      ! begin newton-raphson iteration
       do i=1,npts
          W(i) = sum( a(1:mdims-1)**2*u(i,1:mdims-1) ) + u(i,mdims)
       end do
       if( any( W == 0.0_DP ) ) then
          print *, "Error: subroutine hyperplane: infinite weight in point", &
               minloc(W), " (newton iteration ", iter, ")"
          status = 2            ! failure
          exit
       end if
       W = 1.0_DP/W             ! divide now that we know it is safe
       sumW = sum(W)

       ! Calculate the intercept based on the new a and new weights
       do j=1,mdims
          Ybar(j) = sum(W*Y(1:npts,j))/sumW
          yp(1:npts,j) = Y(1:npts,j) - Ybar(j)
       end do
       a(mdims) = Ybar(mdims) - sum( a(1:mdims-1)*Ybar(1:mdims-1) )

       !  Calculate the classical residuals
       do i=1,npts
          fval(i) = sum( a(1:mdims-1)*yp(i,1:mdims-1) ) - yp(i,mdims)
       end do
       ! z = adjustment Y - y; u is restricted to its first mdims-1
       ! columns so that all operands are npts x (mdims-1)
       z = spread( W*fval, 2, mdims-1)*spread(a(1:mdims-1),1,npts)*u(:,1:mdims-1)
       ! mu = adjusted points y (only the first npts rows of Y are data)
       mu = Y(1:npts,1:mdims-1) - z
       if( verbosity_val >= 4 ) then
          print *, "Updated a="
          write(unit=*,fmt="(4f20.16)")  a(1:mdims)
          print *, "Ybar="
          write(unit=*,fmt=*) Ybar
          print *, "     W         fval      y"
          do i=1,npts
             write(unit=*,fmt=*) W(i), fval(i), mu(i,:)
          end do
       end if

       do j=1,mdims-1
          RHS(j) = sum(W*fval*mu(:,j))
       end do
       ! Calculate average of adjustments
       do j=1,mdims-1
          zbar(j) = sum(W*z(1:npts,j))/sumW
          zp(:,j) = z(:,j) - zbar(j)
       end do
       do j=1,mdims-1
          do k=1,mdims-1
             Mat(j,k) =  sum(W*(yp(:,j) - 2*zp(:,j))*(yp(:,k) - 2*zp(:,k)))
          end do
          Mat(j,j) = Mat(j,j) - sum((W*fval)**2*u(:,j))
       end do

       chisq = sum(W*fval**2)
       if( verbosity_val >= 1 ) then
          write(unit=*,fmt="(a,f13.4)") "Variance      =", chisq
       end if
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) "Mat and RHS:"
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Mat(j,:),DP), real(RHS(j),DP)
          end do
       end if

       ! Solve the system Mat*delta_a = RHS (RHS is overwritten by the solution)
       call dgesv( mdims-1, 1, Mat, mdims-1, ipiv, RHS, mdims-1, lapack_info )

       if( lapack_info /= 0 ) then
          print *, "Failed to solve system of equations for coeffs"
          print *, "DGESV INFO=", lapack_info
          exit
       end if

       old_delta_a(1:mdims-1) = delta_a(1:mdims-1)
       delta_a(1:mdims-1) = RHS
       a(1:mdims-1) = a(1:mdims-1) - delta_a(1:mdims-1)
       a(mdims) = Ybar(mdims) - sum( a(1:mdims-1)*Ybar(1:mdims-1) )

       prev_norm_delta_a = norm_delta_a
       norm_delta_a = sum(abs(delta_a))

       if(verbosity_val >= 1) then
          write(unit=*,fmt=*) "Newton-Raphson iter=", iter, " a="
          write(unit=*,fmt="(4f20.16)")  a(1:mdims)
       end if
       if(verbosity_val >= 2) then
          if( present(exact_a) ) then
             write(unit=*,fmt=*) "error="
             write(unit=*,fmt="(4e11.3)") a(1:mdims-1)-exact_a(1:mdims-1)
          end if
          write(unit=*,fmt=*) "delta a="
          write(unit=*,fmt="(4e11.3)") delta_a(1:mdims-1)
          print *, "1-norm ||a-prev_a||=", norm_delta_a
       end if

       ! test for convergence
       if( all(abs(delta_a) <= conv_crit_val*abs(a(1:mdims-1))) ) then
          status = 0  ! successful convergence
          exit
       end if

       ! Accuracy limit is reached when norm_delta_a doesn't improve.
       ! norm_delta_a is the 1-norm of change in a.  Give it a few iters
       ! to get going.
       if( iter > 3 .and. norm_delta_a > prev_norm_delta_a ) then
          if(verbosity_val >= 1) then
             print *, "Newton limit of accuracy reached at iteration", iter
          end if
          ! If the relative change in A was not small, indicates failure
          if( norm_delta_a > 1.0e-3_DP*sum(abs(a(1:mdims-1))) ) then
             if(verbosity_val >= 1) then
                print *, "Change on last step not small:", norm_delta_a
             end if
             status = 2         ! then indicate failure
          end if
          exit                  ! status not set to 0 since conv_crit not met
       end if

    end do                      ! end of newton-raphson iteration



    if( present(V) .and. status < 2 ) then

       ! Calculate variances in a.  All LAPACK/BLAS work is done in the
       ! contiguous local matrix Vw (leading dimension exactly mdims);
       ! the result is copied into V(1:mdims,1:mdims) at the end.
       allocate( Vw(1:mdims,1:mdims), cov_ab(1:mdims-1) )

       ! z = adjustment Y - y, called epsilon in new writeup
       z = spread( W*fval, 2, mdims-1)*spread(a(1:mdims-1),1,npts)*u(:,1:mdims-1)
       ! Calculate average of adjustments and z' = z - zbar
       do j=1,mdims-1
          zbar(j) = sum(W*z(1:npts,j))/sumW
          zp(:,j) = z(:,j) - zbar(j)
       end do

       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " zbar="
          write(unit=*,fmt=*) real(zbar(1:mdims-1),SP)
          write(unit=*,fmt=*) " Ybar="
          write(unit=*,fmt=*) real(Ybar(1:mdims-1),SP)
       end if

       ! Mat = Jacobian matrix, called J in new writeup
       do j=1,mdims-1
          do k=1,mdims-1
             Mat(j,k) =  sum(W*(yp(:,j) - 2*zp(:,j))*(yp(:,k) - 2*zp(:,k)))
          end do
          Mat(j,j) = Mat(j,j) - sum((W*fval)**2*u(:,j))
       end do


       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " J="
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Mat(j,1:mdims-1),SP)
          end do
       end if

       ! Vw starts out as dA/dY sigma dA/dY^T, called Q in new writeup
       do j=1,mdims-1
          do k=1,mdims-1
             Vw(j,k) = sum(W*(yp(:,j) - zp(:,j))*(yp(:,k) - zp(:,k))) &
                  - sum(W*zp(:,j)*zp(:,k))
          end do
          Vw(j,j) = Vw(j,j) + sum((W*fval)**2*u(:,j))
       end do

       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " Q="
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Vw(j,1:mdims-1),SP)
          end do
       end if

       ! LU-factor J once; the factors are reused by every solve below
       call dgetrf( mdims-1, mdims-1, Mat, mdims-1, ipiv, lapack_info )

       if( lapack_info /= 0 ) then
          print *, "Failed to factor system of equations for derivs"
          print *, "DGETRF INFO=", lapack_info
          status = 2
          return
       end if

       call dgetrs( "N", mdims-1, mdims-1, Mat, mdims-1, ipiv, Vw, mdims, lapack_info)
       if( lapack_info /= 0 ) then
          print *, "Failed to backsolve system of equations for V=J^-1*Q"
          print *, "DGETRS INFO=", lapack_info
          status = 2
          return
       end if

       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " V=J^-1*Q="
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Vw(j,1:mdims-1),SP)
          end do
       end if
       ! Transpose then solve again: V = J^-1 * (J^-1*Q)^T
       Vw(1:mdims-1,1:mdims-1) = transpose(Vw(1:mdims-1,1:mdims-1))
       call dgetrs( "N", mdims-1, mdims-1, Mat, mdims-1, ipiv, Vw, mdims, lapack_info)
       if( lapack_info /= 0 ) then
          print *, "Failed to backsolve system of equations for V=J^-1*V'"
          print *, "DGETRS INFO=", lapack_info
          status = 2
          return
       end if
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " V=J^-1*V="
          do j=1,mdims-1
             write(unit=*,fmt=*) real(Vw(j,1:mdims-1),SP)
          end do
       end if


!  Calculate covariances for row & col = mdim
       ! First calculate J^-1*zbar - V*(ybar-2zbar), which is cov(a,b) and
       ! a term in var(b).  delta_a is no longer needed by the iteration,
       ! so it is reused to hold J^-1*zbar.
       delta_a(1:mdims-1) = zbar(1:mdims-1)
       RHS = 2.0_DP*zbar(1:mdims-1) - Ybar(1:mdims-1)
       call dgetrs( "N", mdims-1, 1, Mat, mdims-1, ipiv, delta_a, mdims-1, lapack_info)
       if( lapack_info /= 0 ) then
          print *, "Failed to backsolve system of equations for J^-1*zbar"
          print *, "DGETRS INFO=", lapack_info
          status = 2
          return
       end if
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " J^-1*zbar="
          write(unit=*,fmt=*) real(delta_a(1:mdims-1),SP)
          write(unit=*,fmt=*) " 2*zbar - Ybar="
          write(unit=*,fmt=*) real(RHS(1:mdims-1),SP)
       end if
       ! cov_ab = J^-1*zbar + V*(2*zbar - Ybar); computed in a separate
       ! vector so dgemv's output does not alias its matrix argument.
       cov_ab(1:mdims-1) = delta_a(1:mdims-1)
       call dgemv("N", mdims-1, mdims-1, 1.0_DP, Vw, mdims, &
            RHS, 1, 1.0_DP, cov_ab, 1)
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " cov(a,b) ="
          write(unit=*,fmt=*) real(cov_ab(1:mdims-1),SP)
       end if
       Vw(1:mdims-1,mdims) = cov_ab(1:mdims-1)
       Vw(mdims,1:mdims-1) = cov_ab(1:mdims-1)
       ! now set delta_a = 2*J^-1*zbar - V*(ybar-2zbar)
       delta_a(1:mdims-1) = delta_a(1:mdims-1) + cov_ab(1:mdims-1)
       ! next multiply by ybar-2zbar and add 1/sum(W_i) to get var(b)
       Vw(mdims,mdims) = ddot(mdims-1,delta_a,1,RHS,1)
       if( verbosity_val >= 3 ) then
          write(unit=*,fmt=*) " cov(a,b)'*(2*zbar - Ybar)="
          write(unit=*,fmt=*) real(Vw(mdims,mdims),SP)
       end if
       Vw(mdims,mdims) = Vw(mdims,mdims)  + 1.0_DP/sumW

       V(1:mdims,1:mdims) = Vw
       deallocate( Vw, cov_ab )
    end if  ! present(V)

    deallocate(zbar,mu)
    deallocate( W, fval, Ybar, prev_a )
    deallocate( Mat, RHS, ipiv )
    deallocate( u, yp, z, zp )
    deallocate(delta_a,old_delta_a)

  end subroutine hyperplane
end module hyperplane_fit