diff --git a/src/states/states_elec.F90 b/src/states/states_elec.F90 index 1240beae911be70f3c2496d6423eac06ca7046a4..1e7eb71cfe3995c9984e66bf5656704671213217 100644 --- a/src/states/states_elec.F90 +++ b/src/states/states_elec.F90 @@ -86,6 +86,7 @@ module states_elec_oct_m states_elec_calc_quantities, & state_is_local, & state_kpt_is_local, & + kpt_is_local, & states_elec_choose_kpoints, & states_elec_distribute_nodes, & states_elec_wfns_memory, & @@ -2322,12 +2323,25 @@ contains PUSH_SUB(state_kpt_is_local) - state_kpt_is_local = ist >= st%st_start .and. ist <= st%st_end .and. & - ik >= st%d%kpt%start .and. ik <= st%d%kpt%end + state_kpt_is_local = state_is_local(st, ist) .and. kpt_is_local(st, ik) POP_SUB(state_kpt_is_local) end function state_kpt_is_local + ! --------------------------------------------------------- + !> check whether a given kpoint (ik) is on the local node + ! + logical function kpt_is_local(st, ik) + type(states_elec_t), intent(in) :: st + integer, intent(in) :: ik + + PUSH_SUB(kpt_is_local) + + kpt_is_local = ik >= st%d%kpt%start .and. ik <= st%d%kpt%end + + POP_SUB(kpt_is_local) + end function kpt_is_local + ! --------------------------------------------------------- !> return the memory usage of a states_elec_t object diff --git a/src/states/states_elec_calc_inc.F90 b/src/states/states_elec_calc_inc.F90 index 7124ba8a03b2a82875660f6ee9e78eb3a9a5afc4..b162768bbe665eb04b82f6c22c150fd1bf33f785 100644 --- a/src/states/states_elec_calc_inc.F90 +++ b/src/states/states_elec_calc_inc.F90 @@ -1705,186 +1705,182 @@ subroutine X(states_elec_rrqr_decomposition)(st, namespace, mesh, nst, root, ik, do_serial = .true. end if - if (.not. do_serial) then - !TODO: Implement the spinor case - ASSERT(st%d%dim == 1) + if (kpt_is_local(st, ik)) then + if (.not. 
do_serial) then + !TODO: Implement the spinor case + ASSERT(st%d%dim == 1) - call states_elec_parallel_blacs_blocksize(st, namespace, mesh, psi_block, total_np) + call states_elec_parallel_blacs_blocksize(st, namespace, mesh, psi_block, total_np) - ! allocate local part of transpose state matrix - SAFE_ALLOCATE(KSt(1:lnst,1:total_np)) - SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim)) + ! allocate local part of transpose state matrix + SAFE_ALLOCATE(KSt(1:lnst,1:total_np)) + SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim)) - ! copy states into the transpose matrix - count = 0 - do ist = st%st_start,st%st_end - count = count + 1 + ! copy states into the transpose matrix + count = 0 + do ist = st%st_start,st%st_end + count = count + 1 - call states_elec_get_state(st, mesh, ist, ik, psi) + call states_elec_get_state(st, mesh, ist, ik, psi) - ! We need to set to zero some extra parts of the array - if (st%d%dim == 1) then - psi(mesh%np + 1:psi_block(1), 1:st%d%dim) = M_ZERO - else - psi(mesh%np + 1:mesh%np_part, 1:st%d%dim) = M_ZERO - end if + ! We need to set to zero some extra parts of the array + if (st%d%dim == 1) then + psi(mesh%np + 1:psi_block(1), 1:st%d%dim) = M_ZERO + else + psi(mesh%np + 1:mesh%np_part, 1:st%d%dim) = M_ZERO + end if - KSt(count, 1:total_np) = psi(1:total_np, 1) - end do + KSt(count, 1:total_np) = psi(1:total_np, 1) + end do - SAFE_DEALLOCATE_A(psi) + SAFE_DEALLOCATE_A(psi) - ! DISTRIBUTE THE MATRIX ON THE PROCESS GRID - ! Initialize the descriptor array for the main matrices (ScaLAPACK) + ! DISTRIBUTE THE MATRIX ON THE PROCESS GRID + ! 
Initialize the descriptor array for the main matrices (ScaLAPACK) #ifdef HAVE_SCALAPACK - call descinit(psi_desc(1), nst, total_np, psi_block(2), psi_block(1), 0, 0, & - st%dom_st_proc_grid%context, lnst, blacs_info) + call descinit(psi_desc(1), nst, total_np, psi_block(2), psi_block(1), 0, 0, & + st%dom_st_proc_grid%context, lnst, blacs_info) #endif - if (blacs_info /= 0) then - write(message(1),'(a,i6)') 'descinit failed with error code: ', blacs_info - call messages_fatal(1, namespace=namespace) - end if - - nref = min(nst, total_np) - SAFE_ALLOCATE(tau(1:nref)) - tau = M_ZERO - - ! calculate the QR decomposition - SAFE_ALLOCATE(ipiv(1:total_np)) - ipiv(1:total_np) = 0 - - ! Note: lapack routine has different number of arguments depending on type -#ifdef HAVE_SCALAPACK -#ifndef R_TREAL - call pzgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), tmp, -1, tmp2, -1, blacs_info) -#else - call pdgeqpf( nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), tmp, -1, blacs_info) -#endif -#endif - - if (blacs_info /= 0) then - write(message(1),'(a,i6)') 'scalapack geqrf workspace query failed with error code: ', blacs_info - call messages_fatal(1, namespace=namespace) - end if + if (blacs_info /= 0) then + write(message(1),'(a,i6)') 'descinit failed with error code: ', blacs_info + call messages_fatal(1, namespace=namespace) + end if - wsize = nint(R_REAL(tmp)) - SAFE_ALLOCATE(work(1:wsize)) -#ifdef HAVE_SCALAPACK -#ifndef R_TREAL - rwsize = max(1,nint(R_REAL(tmp2))) - SAFE_ALLOCATE(rwork(1:rwsize)) - call pzgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), work(1), wsize, rwork(1), rwsize, blacs_info) - SAFE_DEALLOCATE_A(rwork) -#else - call pdgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), work(1), wsize, blacs_info) -#endif -#endif + nref = min(nst, total_np) + SAFE_ALLOCATE(tau(1:nref)) + tau = M_ZERO + + ! calculate the QR decomposition + SAFE_ALLOCATE(ipiv(1:total_np)) + ipiv(1:total_np) = 0 + + ! 
Note: lapack routine has different number of arguments depending on type + #ifdef HAVE_SCALAPACK + #ifndef R_TREAL + call pzgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), tmp, -1, tmp2, -1, blacs_info) + #else + call pdgeqpf( nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), tmp, -1, blacs_info) + #endif + #endif + + if (blacs_info /= 0) then + write(message(1),'(a,i6)') 'scalapack geqrf workspace query failed with error code: ', blacs_info + call messages_fatal(1, namespace=namespace) + end if - if (blacs_info /= 0) then - write(message(1),'(a,i6)') 'scalapack geqrf call failed with error code: ', blacs_info - call messages_fatal(1, namespace=namespace) - end if - SAFE_DEALLOCATE_A(work) + wsize = nint(R_REAL(tmp)) + SAFE_ALLOCATE(work(1:wsize)) + #ifdef HAVE_SCALAPACK + #ifndef R_TREAL + rwsize = max(1,nint(R_REAL(tmp2))) + SAFE_ALLOCATE(rwork(1:rwsize)) + call pzgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), work(1), wsize, rwork(1), rwsize, blacs_info) + SAFE_DEALLOCATE_A(rwork) + #else + call pdgeqpf(nst, total_np, KSt(1,1), 1, 1, psi_desc(1), ipiv(1), tau(1), work(1), wsize, blacs_info) + #endif + #endif + + if (blacs_info /= 0) then + write(message(1),'(a,i6)') 'scalapack geqrf call failed with error code: ', blacs_info + call messages_fatal(1, namespace=namespace) + end if + SAFE_DEALLOCATE_A(work) - ! copy the first nst global elements of ipiv into jpvt - ! bcast is at the end of the routine -! if (mpi_world%rank == 0) then -! do ist =1,nst -! write(123,*) ipiv(ist) -! end do -! end if - jpvt(1:nst) = ipiv(1:nst) + ! copy the first nst global elements of ipiv into jpvt + ! bcast is at the end of the routine + jpvt(1:nst) = ipiv(1:nst) - else - ! first gather states into one array on the root process - ! build transpose of KS set on which RRQR is performed + else + ! first gather states into one array on the root process + ! build transpose of KS set on which RRQR is performed - ! We follow Pizzi et al., J. 
Phys.: Condens. Matter 32 (2020) 165902 - ! for doing the extension to spinors - if (root) then - SAFE_ALLOCATE(KSt(1:nst,1:mesh%np_global*st%d%dim)) - end if + ! We follow Pizzi et al., J. Phys.: Condens. Matter 32 (2020) 165902 + ! for doing the extension to spinors + if (root) then + SAFE_ALLOCATE(KSt(1:nst,1:mesh%np_global*st%d%dim)) + end if - ! gather states in case of domain parallelization - if (mesh%parallel_in_domains.or.st%parallel_in_states) then - SAFE_ALLOCATE(temp_state(1:mesh%np)) - SAFE_ALLOCATE(state_global(1:mesh%np_global)) + ! gather states in case of domain parallelization + if (mesh%parallel_in_domains.or.st%parallel_in_states) then + SAFE_ALLOCATE(temp_state(1:mesh%np)) + SAFE_ALLOCATE(state_global(1:mesh%np_global)) - count = 0 - do ii = 1, nst - do idim = 1, st%d%dim - !we are copying states like this: KSt(i,:) = st%psi(:,dim,i,nik) - state_global(1:mesh%np_global) = M_ZERO - sender = 0 - if (state_is_local(st,ii)) then - call states_elec_get_state(st, mesh, idim, ii, ik, temp_state) - call par_vec_gather(mesh%pv, 0, temp_state(1:mesh%np), state_global) - if (mesh%mpi_grp%rank == 0) sender = mpi_world%rank - end if - call comm_allreduce(mpi_world,sender) - call mpi_world%bcast(state_global(1), mesh%np_global, R_MPITYPE, sender) - ! keep full Kohn-Sham matrix only on root - if (root) KSt(ii, (idim-1)*mesh%np_global+1:idim*mesh%np_global) = st%occ(ii, ik)*R_CONJ(state_global(1:mesh%np_global)) + count = 0 + do ii = 1, nst + do idim = 1, st%d%dim + !we are copying states like this: KSt(i,:) = st%psi(:,dim,i,nik) + state_global(1:mesh%np_global) = M_ZERO + sender = 0 + if (state_is_local(st,ii)) then + call states_elec_get_state(st, mesh, idim, ii, ik, temp_state) + call par_vec_gather(mesh%pv, 0, temp_state(1:mesh%np), state_global) + if (mesh%mpi_grp%rank == 0) sender = mpi_world%rank + end if + call comm_allreduce(mpi_world,sender) + call mpi_world%bcast(state_global(1), mesh%np_global, R_MPITYPE, sender) + ! 
keep full Kohn-Sham matrix only on root + if (root) KSt(ii, (idim-1)*mesh%np_global+1:idim*mesh%np_global) = st%occ(ii, ik)*R_CONJ(state_global(1:mesh%np_global)) + end do end do - end do - SAFE_DEALLOCATE_A(state_global) - SAFE_DEALLOCATE_A(temp_state) - else - ! serial - SAFE_ALLOCATE(temp_state(1:mesh%np)) - do ii = 1, nst - do idim = 1, st%d%dim - ! this call is necessary becasue we want to have only np not np_part - call states_elec_get_state(st, mesh, idim, ii, ik, temp_state) - KSt(ii, (idim-1)*mesh%np+1:idim*mesh%np) = st%occ(ii,ik)*R_CONJ(temp_state) + SAFE_DEALLOCATE_A(state_global) + SAFE_DEALLOCATE_A(temp_state) + else + ! serial + SAFE_ALLOCATE(temp_state(1:mesh%np)) + do ii = 1, nst + do idim = 1, st%d%dim + ! this call is necessary because we want to have only np not np_part + call states_elec_get_state(st, mesh, idim, ii, ik, temp_state) + KSt(ii, (idim-1)*mesh%np+1:idim*mesh%np) = st%occ(ii,ik)*R_CONJ(temp_state) + end do end do - end do - SAFE_DEALLOCATE_A(temp_state) - end if + SAFE_DEALLOCATE_A(temp_state) + end if - ! now perform serial RRQR - ! dummy call to obtain dimension of work - ! Note: the lapack routine has different number of arguments depending on type - if (root) then - SAFE_ALLOCATE(work(1)) - SAFE_ALLOCATE(tau(1:nst)) - ASSERT(mesh%np_global*st%d%dim < huge(0_int32)) - ASSERT(size(jpvt) >= mesh%np_global*st%d%dim) + ! now perform serial RRQR + ! dummy call to obtain dimension of work + ! 
Note: the lapack routine has different number of arguments depending on type + if (root) then + SAFE_ALLOCATE(work(1)) + SAFE_ALLOCATE(tau(1:nst)) + ASSERT(mesh%np_global*st%d%dim < huge(0_int32)) + ASSERT(size(jpvt) >= mesh%np_global*st%d%dim) #ifdef R_TREAL - call dgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, -1, info) + call dgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, -1, info) #else - SAFE_ALLOCATE(rwork(1:2*mesh%np_global*st%d%dim)) - call zgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, -1, rwork, info) + SAFE_ALLOCATE(rwork(1:2*mesh%np_global*st%d%dim)) + call zgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, -1, rwork, info) #endif - if (info /= 0) then - write(message(1),'(A28,I2)') 'Illegal argument in ZGEQP3: ', info - call messages_fatal(1, namespace=namespace) - end if + if (info /= 0) then + write(message(1),'(A28,I2)') 'Illegal argument in ZGEQP3: ', info + call messages_fatal(1, namespace=namespace) + end if - wsize = int(work(1)) - SAFE_DEALLOCATE_A(work) - SAFE_ALLOCATE(work(1:wsize)) + wsize = int(work(1)) + SAFE_DEALLOCATE_A(work) + SAFE_ALLOCATE(work(1:wsize)) - jpvt(:) = 0 - tau(:) = M_ZERO - ! actual call + jpvt(:) = 0 + tau(:) = M_ZERO + ! 
actual call #ifdef R_TREAL - call dgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, wsize, info) + call dgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, wsize, info) #else - call zgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, wsize, rwork, info) + call zgeqp3(nst, i8_to_i4(mesh%np_global*st%d%dim), kst, nst, jpvt, tau, work, wsize, rwork, info) #endif - if (info /= 0) then - write(message(1),'(A28,I2)') 'Illegal argument in ZGEQP3: ', info - call messages_fatal(1, namespace=namespace) + if (info /= 0) then + write(message(1),'(A28,I2)') 'Illegal argument in ZGEQP3: ', info + call messages_fatal(1, namespace=namespace) + end if + SAFE_DEALLOCATE_A(work) + SAFE_DEALLOCATE_A(tau) end if - SAFE_DEALLOCATE_A(work) - SAFE_DEALLOCATE_A(tau) - end if - - SAFE_DEALLOCATE_A(temp_state) - SAFE_DEALLOCATE_A(state_global) + SAFE_DEALLOCATE_A(temp_state) + SAFE_DEALLOCATE_A(state_global) + end if end if call mpi_world%bcast(jpvt(1), nst*st%d%dim, MPI_INTEGER, 0)