diff --git a/src/states/elec_matrix_elements.F90 b/src/states/elec_matrix_elements.F90
index a1c1665ff91586c0fa258a2978369b3e8161cb24..a9419d278393e6dee2a4008f7659be5998cac930 100644
--- a/src/states/elec_matrix_elements.F90
+++ b/src/states/elec_matrix_elements.F90
@@ -123,9 +123,9 @@ contains
 
   ! -----------------------------------------------------------------------------
   subroutine elec_momentum_me(this, kpoints, momentum)
-    class(elec_matrix_elements_t), intent(in)  :: this
-    type(kpoints_t),               intent(in)  :: kpoints
-    FLOAT,                         intent(out) :: momentum(:,:,:)
+    class(elec_matrix_elements_t), intent(in)    :: this
+    type(kpoints_t),               intent(in)    :: kpoints
+    FLOAT,                         intent(inout) :: momentum(:,:,:)
 
     PUSH_SUB(elec_momentum_me)
 
diff --git a/src/states/elec_matrix_elements_inc.F90 b/src/states/elec_matrix_elements_inc.F90
index 288b1135acb297fae743456509e86b8882437277..9db0ef728dcddc22dce992910f4fb31e8df051df 100644
--- a/src/states/elec_matrix_elements_inc.F90
+++ b/src/states/elec_matrix_elements_inc.F90
@@ -6,29 +6,32 @@ !!\f]
 
 ! ---------------------------------------------------------
 subroutine X(elec_momentum_me)(this, kpoints, momentum)
-  type(elec_matrix_elements_t), intent(in)  :: this
-  type(kpoints_t),              intent(in)  :: kpoints
-  FLOAT,                        intent(out) :: momentum(:,:,:)
+  type(elec_matrix_elements_t), intent(in)    :: this
+  type(kpoints_t),              intent(in)    :: kpoints
+  ! Only the locally owned states (st_start:st_end) and k-points (kpt%start:kpt%end)
+  ! are computed and no MPI communication is done here: callers must reduce if needed.
+  FLOAT,                        intent(inout) :: momentum(this%space%dim, &
+    this%states%st_start:this%states%st_end, this%states%d%kpt%start:this%states%d%kpt%end)
 
   integer :: idim, ist, ik, idir
   CMPLX :: expect_val_p
   R_TYPE, allocatable :: psi(:, :), grad(:,:,:)
   FLOAT :: kpoint(this%space%dim)
-#if defined(HAVE_MPI)
-  integer :: tmp
-  FLOAT, allocatable :: lmomentum(:), gmomentum(:)
-#endif
-  FLOAT, allocatable :: lmom(:, :, :)
-  integer :: k_start, k_end, k_n, ndim, nst
+  integer :: st_start, st_end, k_start, k_end
 
   PUSH_SUB(X(elec_momentum_me))
 
+  st_start = this%states%st_start
+  st_end = this%states%st_end
+  k_start = this%states%d%kpt%start
+  k_end = this%states%d%kpt%end
+
   SAFE_ALLOCATE(psi(1:this%grid%np_part, 1:this%states%d%dim))
   SAFE_ALLOCATE(grad(1:this%grid%np, 1:this%space%dim, 1:this%states%d%dim))
 
-  nst = this%states%nst
-  do ik = this%states%d%kpt%start, this%states%d%kpt%end
-    do ist = this%states%st_start, this%states%st_end
+  do ik = k_start, k_end
+    kpoint(:) = kpoints%get_point(this%states%d%get_kpoint_index(ik))
+    do ist = st_start, st_end
 
       call states_elec_get_state(this%states, this%grid, ist, ik, psi)
 
@@ -46,57 +49,22 @@ subroutine X(elec_momentum_me)(this, kpoints, momentum)
         ! -i prefactor of p = -i \nabla
         if (states_are_real(this%states)) then
           momentum(idir, ist, ik) = real(expect_val_p)
+          ASSERT(abs(aimag(expect_val_p)) < 1E-10_real64)
         else
           momentum(idir, ist, ik) = real(-M_zI * expect_val_p)
+          ASSERT(abs(real(expect_val_p)) < 1E-10_real64)
         end if
       end do
 
       ! have to add the momentum vector in the case of periodic systems,
       ! since psi contains only u_k
-      kpoint(:) = kpoints%get_point(this%states%d%get_kpoint_index(ik))
       do idir = 1, this%space%periodic_dim
         momentum(idir, ist, ik) = momentum(idir, ist, ik) + kpoint(idir)
       end do
 
     end do
-
-    ! Exchange momenta in the states parallel case.
-#if defined(HAVE_MPI)
-    if (this%states%parallel_in_states) then
-      SAFE_ALLOCATE(lmomentum(1:this%states%lnst))
-      SAFE_ALLOCATE(gmomentum(1:nst))
-
-      do idir = 1, this%states%d%dim
-        lmomentum(1:this%states%lnst) = momentum(idir, this%states%st_start:this%states%st_end, ik)
-        call lmpi_gen_allgatherv(this%states%lnst, lmomentum, tmp, gmomentum, this%states%mpi_grp)
-        momentum(idir, 1:nst, ik) = gmomentum(1:nst)
-      end do
-
-      SAFE_DEALLOCATE_A(lmomentum)
-      SAFE_DEALLOCATE_A(gmomentum)
-    end if
-#endif
   end do
 
-  ! Handle kpoints parallelization
-  if (this%states%d%kpt%parallel) then
-    k_start = this%states%d%kpt%start
-    k_end = this%states%d%kpt%end
-    k_n = this%states%d%kpt%nlocal
-    ndim = ubound(momentum, dim = 1)
-
-    ASSERT(.not. this%states%parallel_in_states)
-
-    SAFE_ALLOCATE(lmom(1:ndim, 1:nst, 1:k_n))
-
-    lmom(1:ndim, 1:nst, 1:k_n) = momentum(1:ndim, 1:nst, k_start:k_end)
-
-    call this%states%d%kpt%mpi_grp%allgatherv(lmom, ndim * nst * k_n, MPI_FLOAT, momentum, &
-      this%states%d%kpt%num(:) * nst * ndim, (this%states%d%kpt%range(1, :) - 1)*nst*ndim, MPI_FLOAT)
-
-    SAFE_DEALLOCATE_A(lmom)
-  end if
-
   SAFE_DEALLOCATE_A(psi)
   SAFE_DEALLOCATE_A(grad)
 
diff --git a/src/td/td_calc.F90 b/src/td/td_calc.F90
index 449bbebb615adb7decd6fc1dee18158ef20fdc2a..3a05a080e03becbd11d5f2e1788ad99c9a73040e 100644
--- a/src/td/td_calc.F90
+++ b/src/td/td_calc.F90
@@ -181,20 +181,35 @@ contains
     FLOAT, intent(out) :: vel(:)
 
     FLOAT, allocatable :: momentum(:,:,:)
-    integer :: st_start, st_end
+    FLOAT :: weight
+    integer :: st_start, st_end, k_start, k_end, ist, ik
 
     PUSH_SUB(td_calc_tvel)
 
     st_start = elec_me%states%st_start
     st_end = elec_me%states%st_end
+    k_start = elec_me%states%d%kpt%start
+    k_end = elec_me%states%d%kpt%end
 
-    SAFE_ALLOCATE(momentum(1:elec_me%space%dim, st_start:st_end, elec_me%states%d%kpt%start:elec_me%states%d%kpt%end))
+    SAFE_ALLOCATE_SOURCE(momentum(1:elec_me%space%dim, st_start:st_end, k_start:k_end), 0.0_real64)
     call elec_me%momentum_me(kpoints, momentum)
 
-    momentum(1:elec_me%space%dim, st_start:st_end, 1) = &
-      sum(momentum(1:elec_me%space%dim, st_start:st_end, elec_me%states%d%kpt%start:elec_me%states%d%kpt%end), 3)
-    momentum(1:elec_me%space%dim, 1, 1) = sum(momentum(1:elec_me%space%dim, st_start:st_end, 1), 2)
-    vel = momentum(:, 1, 1)
+    ! vel is intent(out), so it is undefined on entry: zero it before accumulating.
+    vel(:) = 0.0_real64
+
+    ! Sum over the locally owned states and k-points, weighted by k-weight times occupation.
+    do ik = k_start, k_end
+      do ist = st_start, st_end
+        weight = elec_me%states%kweights(ik) * elec_me%states%occ(ist, ik)
+        if (abs(weight) <= M_EPSILON) cycle
+        vel(:) = vel(:) + weight * momentum(:, ist, ik)
+      end do
+    end do
+
+    ! st_kpt_mpi_grp spans both the states and the k-point distribution, so a single
+    ! reduction gathers every contribution. Reducing again over states%mpi_grp would
+    ! multiply the result by the number of processes in the states group.
+    call comm_allreduce(elec_me%states%st_kpt_mpi_grp, vel)
 
     SAFE_DEALLOCATE_A(momentum)
 
diff --git a/src/td/td_write.F90 b/src/td/td_write.F90
index 2240fe15ed9a37d75fb4536bfb7854d34cc46d66..7d18b2d3825ae8d2596904e83b9a359ad0d3ecf6 100644
--- a/src/td/td_write.F90
+++ b/src/td/td_write.F90
@@ -1835,11 +1835,9 @@ contains
     character(len=7) :: aux
     FLOAT :: vel(elec_me%space%dim)
 
-    if (.not. mpi_grp_is_root(mpi_world)) return ! only first node outputs
-
     PUSH_SUB(td_write_vel)
 
-    if (iter == 0) then
+    if (iter == 0 .and. mpi_grp_is_root(mpi_world)) then
       call td_write_print_header_init(out_vel)
 
       ! first line -> column names
@@ -1862,10 +1860,13 @@ contains
 
     call td_calc_tvel(elec_me, kpoints, vel)
 
-    call write_iter_start(out_vel)
-    vel = units_from_atomic(units_out%velocity, vel)
-    call write_iter_double(out_vel, vel, elec_me%space%dim)
-    call write_iter_nl(out_vel)
+    ! td_calc_tvel above performs MPI reductions and must run on every rank; only the root writes.
+    if (mpi_grp_is_root(mpi_world)) then
+      call write_iter_start(out_vel)
+      vel = units_from_atomic(units_out%velocity, vel)
+      call write_iter_double(out_vel, vel, elec_me%space%dim)
+      call write_iter_nl(out_vel)
+    end if
 
     POP_SUB(td_write_vel)
   end subroutine td_write_vel