import copy
import cmath
import h5py
import math
import numpy
import scipy.linalg
import sys
import time
from pauxy.walkers.multi_ghf import MultiGHFWalker
from pauxy.walkers.single_det import SingleDetWalker
from pauxy.walkers.multi_det import MultiDetWalker
from pauxy.walkers.thermal import ThermalWalker
from pauxy.walkers.stack import FieldConfig
from pauxy.qmc.comm import FakeComm
from pauxy.utils.io import get_input_value
from pauxy.utils.misc import update_stack
class Walkers(object):
    """Container for groups of walkers which make up a wavefunction.

    Parameters
    ----------
    walker_opts : dict
        Walker options. The keys read directly by this class are
        'write_freq', 'write_file', 'read_file', 'population_control',
        'min_weight' and 'max_weight'. The same object is forwarded to the
        individual walker constructors.
    system : object
        System object.
    trial : object
        Trial wavefunction object. ``trial.name`` selects the walker type.
    qmc : object
        QMC options. ``nwalkers`` and ``ntot_walkers`` are read. For thermal
        walkers on a "Hubbard" system ``nstblz`` may be overwritten so that it
        is compatible with the walker stack size.
    verbose : bool
        If True print setup information.
    comm : MPI communicator or None
        Communicator used for restart file I/O. If None restart files are
        accessed without the parallel (mpio) driver and rank 0 is assumed.
    nprop_tot : int
        Total number of propagators to store for back propagation + itcf.
    nbp : int
        Number of back propagation steps.
    """

    def __init__(self, walker_opts, system, trial, qmc, verbose=False,
                 comm=None, nprop_tot=None, nbp=None):
        self.nwalkers = qmc.nwalkers
        self.ntot_walkers = qmc.ntot_walkers
        self.write_freq = walker_opts.get('write_freq', 0)
        self.write_file = walker_opts.get('write_file', 'restart.h5')
        self.read_file = walker_opts.get('read_file', None)
        if verbose:
            print("# Setting up wavefunction object.")
        if trial.name == 'MultiSlater':
            self.walker_type = 'MSD'
            # TODO: FDM FIXTHIS
            if trial.ndets == 1:
                if verbose:
                    print("# Usinge single det walker with msd wavefunction.")
                self.walker_type = 'SD'
                self.walkers = [SingleDetWalker(walker_opts, system, trial,
                                                index=w, nprop_tot=nprop_tot,
                                                nbp=nbp)
                                for w in range(qmc.nwalkers)]
            else:
                self.walkers = [
                    MultiDetWalker(walker_opts, system, trial,
                                   verbose=(verbose and w == 0))
                    for w in range(qmc.nwalkers)
                ]
            self.buff_size = self.walkers[0].buff_size
            if nbp is not None:
                self.buff_size += self.walkers[0].field_configs.buff_size
            self.walker_buffer = numpy.zeros(self.buff_size,
                                             dtype=numpy.complex128)
        elif trial.name == 'thermal':
            self.walker_type = 'thermal'
            self.walkers = [ThermalWalker(walker_opts, system, trial,
                                          verbose and w == 0)
                            for w in range(qmc.nwalkers)]
            self.buff_size = (self.walkers[0].buff_size
                              + self.walkers[0].stack.buff_size)
            self.walker_buffer = numpy.zeros(self.buff_size,
                                             dtype=numpy.complex128)
            stack_size = self.walkers[0].stack_size
            if system.name == "Hubbard":
                if stack_size % qmc.nstblz != 0 or qmc.nstblz < stack_size:
                    if verbose:
                        print("# Stabilisation frequency is not commensurate "
                              "with stack size.")
                        print("# Determining a better value.")
                    if qmc.nstblz < stack_size:
                        qmc.nstblz = stack_size
                        if verbose:
                            print("# Updated stabilization frequency: "
                                  " {}".format(qmc.nstblz))
                    else:
                        qmc.nstblz = update_stack(qmc.nstblz, stack_size,
                                                  name="nstblz",
                                                  verbose=verbose)
        else:
            self.walker_type = 'SD'
            self.walkers = [SingleDetWalker(walker_opts, system, trial,
                                            index=w, nprop_tot=nprop_tot,
                                            nbp=nbp)
                            for w in range(qmc.nwalkers)]
            self.buff_size = self.walkers[0].buff_size
            if nbp is not None:
                if verbose:
                    print("# Performing back propagation.")
                    print("# Number of steps in imaginary time: {:}.".format(nbp))
                self.buff_size += self.walkers[0].field_configs.buff_size
            self.walker_buffer = numpy.zeros(self.buff_size,
                                             dtype=numpy.complex128)
        self.pcont_method = get_input_value(walker_opts, 'population_control',
                                            default='comb')
        self.min_weight = walker_opts.get('min_weight', 0.1)
        self.max_weight = walker_opts.get('max_weight', 4.0)
        if verbose:
            print("# Using {} population control "
                  "algorithm.".format(self.pcont_method))
            mem = float(self.walker_buffer.nbytes) / (1024.0**3)
            print("# Buffer size for communication: {:13.8e} GB".format(mem))
            if mem > 2.0:
                # TODO: FDM FIX THIS
                print(" # Warning: Walker buffer size > 2GB. May run into MPI "
                      "issues.")
        if self.write_freq > 0:
            if self.walker_type == "thermal":
                # The restart layout is [weight, phase, ot, phi], see
                # get_write_buffer, which is only sized for walkers with phi.
                raise NotImplementedError("Writing restart files is not "
                                          "supported for thermal walkers.")
            self.write_restart = True
            self.dsets = []
            # Three scalars (weight, phase, ot) followed by the flattened phi.
            walker_size = 3 + self.walkers[0].phi.size
            with self._open_restart_file('w', comm) as fh5:
                for i in range(self.ntot_walkers):
                    fh5.create_dataset('walker_%d'%i, (walker_size,),
                                       dtype=numpy.complex128)
        else:
            self.write_restart = False
        if self.read_file is not None:
            if verbose:
                print("# Reading walkers from %s file series."%self.read_file)
            self.read_walkers(comm)
        self.target_weight = qmc.ntot_walkers
        self.nw = qmc.nwalkers
        self.set_total_weight(qmc.ntot_walkers)

    def _open_restart_file(self, mode, comm):
        """Open the restart file, using the mpio driver only if comm is given.

        Parameters
        ----------
        mode : str
            File mode passed to h5py.File.
        comm : MPI communicator or None
            Communicator handed to the mpio driver. If None the file is opened
            with the default driver.

        Returns
        -------
        fh5 : h5py.File
            Open file handle for ``self.write_file``.
        """
        if comm is None:
            return h5py.File(self.write_file, mode)
        return h5py.File(self.write_file, mode, driver='mpio', comm=comm)

    def orthogonalise(self, trial, free_projection):
        """Orthogonalise all walkers.

        Parameters
        ----------
        trial : object
            Trial wavefunction object.
        free_projection : bool
            True if doing free projection. In this case the modulus and phase
            of the value returned by each walker's ``reortho`` are folded into
            the walker's weight and phase.
        """
        for w in self.walkers:
            detR = w.reortho(trial)
            if free_projection:
                (magn, dtheta) = cmath.polar(detR)
                w.weight *= magn
                w.phase *= cmath.exp(1j*dtheta)

    def add_field_config(self, nprop_tot, nbp, system, dtype):
        """Add FieldConfig object to walker object.

        Parameters
        ----------
        nprop_tot : int
            Total number of propagators to store for back propagation + itcf.
        nbp : int
            Number of back propagation steps.
        system : object
            System object. ``system.nfields`` is passed to FieldConfig.
        dtype : type
            Field configuration type.
        """
        for w in self.walkers:
            w.field_configs = FieldConfig(system.nfields, nprop_tot, nbp, dtype)

    def copy_historic_wfn(self):
        """Copy current wavefunction to psi_n for next back propagation step."""
        for w in self.walkers:
            numpy.copyto(w.phi_old, w.phi)

    def copy_bp_wfn(self, phi_bp):
        """Copy back propagated wavefunction.

        Parameters
        ----------
        phi_bp : object
            list of walker objects containing back propagated walkers.
        """
        for (w, wbp) in zip(self.walkers, phi_bp):
            numpy.copyto(w.phi_bp, wbp.phi)

    def copy_init_wfn(self):
        """Copy current wavefunction to initial wavefunction.

        The definition of the initial wavefunction depends on whether we are
        calculating an ITCF or not.
        """
        for w in self.walkers:
            numpy.copyto(w.phi_right, w.phi)

    def pop_control(self, comm):
        """Rescale walker weights and apply the selected branching algorithm.

        The absolute walker weights are gathered from all processes, every
        weight is divided by (total weight / target weight) and then either
        ``comb`` or ``pair_branch`` is called depending on
        ``self.pcont_method``. The process exits if the total weight is
        smaller than 1e-8.

        Parameters
        ----------
        comm : MPI communicator
        """
        weights = numpy.array([abs(w.weight) for w in self.walkers])
        global_weights = numpy.empty(len(weights)*comm.size)
        comm.Allgather(weights, global_weights)
        total_weight = sum(global_weights)
        if total_weight < 1e-8:
            if comm.rank == 0:
                print("# Warning: Total weight is {:13.8e}: "
                      .format(total_weight))
                print("# Something is seriously wrong.")
            sys.exit()
        # Rescale weights to combat exponential decay/growth.
        scale = total_weight / self.target_weight
        self.set_total_weight(total_weight)
        # Todo: Just standardise information we want to send between routines.
        for w in self.walkers:
            w.unscaled_weight = w.weight
            w.weight = w.weight / scale
        if self.pcont_method == "comb":
            global_weights = global_weights / scale
            self.comb(comm, global_weights)
        elif self.pcont_method == "pair_branch":
            self.pair_branch(comm)
        else:
            if comm.rank == 0:
                print("Unknown population control method.")

    def comb(self, comm, weights):
        """Apply the comb method of population control / branching.

        See Booth & Gubernatis PRE 80, 046704 (2009).

        Rank 0 decides how many copies of every walker survive and broadcasts
        the result. Walkers with no copies are then overwritten by the buffer
        of a walker with more than one copy, and all local weights are reset
        to 1.

        Parameters
        ----------
        comm : MPI communicator
        weights : :class:`numpy.ndarray`
            Weights of all walkers on all processes, ordered by rank.
        """
        # todo : add phase to walker for free projection
        nweights = len(weights)
        if comm.rank == 0:
            parent_ix = numpy.zeros(nweights, dtype='i')
            total_weight = sum(weights)
            cprobs = numpy.cumsum(weights)
            r = numpy.random.random()
            teeth = [(i+r) * (total_weight/self.target_weight)
                     for i in range(self.target_weight)]
            last = nweights - 1
            iw = 0
            ic = 0
            while ic < len(teeth):
                # The check on last stops rounding differences between
                # sum and cumsum from pushing iw beyond the final walker.
                if iw == last or teeth[ic] < cprobs[iw]:
                    parent_ix[iw] += 1
                    ic += 1
                else:
                    iw += 1
            data = {'ix': parent_ix}
        else:
            data = None
        data = comm.bcast(data, root=0)
        parent_ix = data['ix']
        # where returns a tuple (array,), selecting first element.
        kill = numpy.where(parent_ix == 0)[0]
        multi = numpy.where(parent_ix > 1)[0]
        # A walker selected n times has to overwrite n-1 killed walkers, so it
        # is listed n-1 times. Listing it once would leave walkers which
        # should have been killed in the population whenever n > 2.
        clone = numpy.repeat(multi, parent_ix[multi] - 1)
        reqs = []
        # Keep a reference to every send buffer until the sends have completed.
        send_buffers = []
        # First initiate non-blocking sends of walkers.
        for i, (c, k) in enumerate(zip(clone, kill)):
            # Sending from current processor?
            if c // self.nw == comm.rank:
                # Location of walker to clone in local list.
                clone_pos = c % self.nw
                dest_proc = int(k // self.nw)
                buff = self.walkers[clone_pos].get_buffer()
                send_buffers.append(buff)
                reqs.append(comm.Isend(buff, dest=dest_proc, tag=i))
        # Now receive walkers on processors where walkers are to be killed.
        for i, (c, k) in enumerate(zip(clone, kill)):
            # Receiving to current processor?
            if k // self.nw == comm.rank:
                # Processor we are receiving from.
                source_proc = int(c // self.nw)
                # Location of walker to kill in local list of walkers.
                kill_pos = k % self.nw
                comm.Recv(self.walker_buffer, source=source_proc, tag=i)
                self.walkers[kill_pos].set_buffer(self.walker_buffer)
        # Complete non-blocking send.
        for rs in reqs:
            rs.wait()
        # Necessary?
        comm.Barrier()
        # Reset walker weight.
        # TODO: check this.
        for w in self.walkers:
            w.weight = 1.0

    def pair_branch(self, comm):
        """Apply pair branching population control.

        Walker information is gathered on rank 0 and sorted by weight. The
        smallest and largest remaining walkers are paired for as long as the
        small one is below ``self.min_weight`` or the large one is above
        ``self.max_weight``. One walker of each pair is killed and the other is
        cloned with half of the combined weight; the large walker is cloned
        with probability (large weight) / (combined weight).

        Parameters
        ----------
        comm : MPI communicator
        """
        # Columns: weight, number of copies, owning rank, partner rank.
        walker_info = [[abs(w.weight), 1, comm.rank, comm.rank]
                       for w in self.walkers]
        glob_inf = comm.gather(walker_info, root=0)
        if comm.rank == 0:
            glob_inf = numpy.array([item for sub in glob_inf for item in sub])
            sort = numpy.argsort(glob_inf[:,0], kind='mergesort')
            # Inverse permutation to restore the original walker order.
            isort = numpy.argsort(sort, kind='mergesort')
            glob_inf = glob_inf[sort]
            s = 0
            e = len(glob_inf) - 1
            while s < e:
                if (glob_inf[s][0] >= self.min_weight
                        and glob_inf[e][0] <= self.max_weight):
                    break
                # sum of paired walker weights
                wab = glob_inf[s][0] + glob_inf[e][0]
                r = numpy.random.rand()
                if r < glob_inf[e][0] / wab:
                    # clone large weight walker
                    glob_inf[e][0] = 0.5 * wab
                    glob_inf[e][1] = 2
                    # Processor we will send duplicated walker to
                    glob_inf[e][3] = glob_inf[s][2]
                    # Kill small weight walker
                    glob_inf[s][0] = 0.0
                    glob_inf[s][1] = 0
                    glob_inf[s][3] = glob_inf[e][2]
                else:
                    # clone small weight walker
                    glob_inf[s][0] = 0.5 * wab
                    glob_inf[s][1] = 2
                    # Processor we will send duplicated walker to
                    glob_inf[s][3] = glob_inf[e][2]
                    # Kill large weight walker
                    glob_inf[e][0] = 0.0
                    glob_inf[e][1] = 0
                    glob_inf[e][3] = glob_inf[s][2]
                s += 1
                e -= 1
            glob_inf = glob_inf[isort].reshape((comm.size, self.nwalkers, 4))
        data = comm.scatter(glob_inf, root=0)
        nlocal = len(walker_info)
        reqs = []
        for iw, walker in enumerate(data):
            if walker[1] > 1:
                # The gathered table is a float array, so ranks and tags have
                # to be converted back to integers before they are used.
                dest = int(round(walker[3]))
                tag = comm.rank*nlocal + dest
                self.walkers[iw].weight = walker[0]
                buff = self.walkers[iw].get_buffer()
                reqs.append(comm.Isend(buff, dest=dest, tag=tag))
        for iw, walker in enumerate(data):
            if walker[1] == 0:
                source = int(round(walker[3]))
                tag = source*nlocal + comm.rank
                comm.Recv(self.walker_buffer, source=source, tag=tag)
                self.walkers[iw].set_buffer(self.walker_buffer)
        for r in reqs:
            r.wait()

    def recompute_greens_function(self, trial, time_slice=None):
        """Call ``greens_function(trial, time_slice)`` on every walker."""
        for w in self.walkers:
            w.greens_function(trial, time_slice)

    def set_total_weight(self, total_weight):
        """Set total_weight and old_total_weight on every walker.

        Parameters
        ----------
        total_weight : float
            Value stored in both attributes.
        """
        for w in self.walkers:
            w.total_weight = total_weight
            w.old_total_weight = w.total_weight

    def reset(self, trial):
        """Reset the stack, Green's function, weight and phase of all walkers.

        Parameters
        ----------
        trial : object
            Trial object. ``trial.dmat`` is used to fill the walker stack.
        """
        for w in self.walkers:
            w.stack.reset()
            w.stack.set_all(trial.dmat)
            w.greens_function(trial)
            w.weight = 1.0
            w.phase = 1.0 + 0.0j

    def get_write_buffer(self, i):
        """Pack walker i into a flat array for the restart file.

        Parameters
        ----------
        i : int
            Local walker index.

        Returns
        -------
        buff : :class:`numpy.ndarray`
            Array holding weight, phase, ot and the flattened phi, in that
            order.
        """
        w = self.walkers[i]
        buff = numpy.concatenate([[w.weight], [w.phase], [w.ot], w.phi.ravel()])
        return buff

    def set_walker_from_buffer(self, i, buff):
        """Restore walker i from an array produced by get_write_buffer.

        Parameters
        ----------
        i : int
            Local walker index.
        buff : :class:`numpy.ndarray`
            Array holding weight, phase, ot and the flattened phi, in that
            order. phi is reshaped to the walker's current phi shape.
        """
        w = self.walkers[i]
        w.weight = buff[0]
        w.phase = buff[1]
        w.ot = buff[2]
        w.phi = buff[3:].reshape(w.phi.shape)

    def write_walkers(self, comm):
        """Write all local walkers to the restart file.

        Local walker i is written to dataset 'walker_<i + nwalkers*rank>'.

        Parameters
        ----------
        comm : MPI communicator or None
            If None rank 0 is assumed and the file is opened without the mpio
            driver.
        """
        start = time.time()
        rank = 0 if comm is None else comm.rank
        with self._open_restart_file('r+', comm) as fh5:
            for i in range(len(self.walkers)):
                ix = i + self.nwalkers*rank
                fh5['walker_%d'%ix][:] = self.get_write_buffer(i)
        if rank == 0:
            print(" # Writing walkers to file.")
            print(" # Time to write restart: {:13.8e} s"
                  .format(time.time()-start))

    def read_walkers(self, comm):
        """Read all local walkers from ``self.read_file``.

        Local walker i is read from dataset 'walker_<i + nwalkers*rank>'. A
        message is printed for every dataset which is missing and the
        corresponding walker is left unchanged.

        Parameters
        ----------
        comm : MPI communicator or None
            If None rank 0 is assumed.
        """
        rank = 0 if comm is None else comm.rank
        with h5py.File(self.read_file, 'r') as fh5:
            for i in range(len(self.walkers)):
                ix = i + self.nwalkers*rank
                try:
                    buff = fh5['walker_%d'%ix][:]
                except KeyError:
                    print(" # Could not read walker data from:"
                          " %s"%(self.read_file))
                else:
                    self.set_walker_from_buffer(i, buff)