Calculating Bispectrum with dask-ms and without loops
Hackasteroid142 opened this issue
- dask-ms version: 0.2.15
- Python version: 3.9.16
- Operating System: Ubuntu Linux 22.04
Hi everyone, I am trying to implement the Bispectrum method with dask-ms and without any for loops. I currently have an implementation, but when I call compute() on the result my computer freezes.
The implementation finds all trios of antennas (i,j,k), finds the visibilities whose two antennas form one of the pairs of the trio ([i,j], [j,k] or [k,i]), and multiplies the visibilities of those pairs to obtain the bispectrum.
My idea is that, when looking for the positions of the visibilities that satisfy the antenna-pair condition, the resulting array of positions has the same shape as the array of visibilities, so that multiplying the two sets the values that are not needed to 0.
This solution does not take long to run but, as I said, as soon as I call compute() to see the result my computer freezes. So far this is the only way I have found to apply the method without for loops; my previous implementations used loops, but they took a long time to finish. So I wanted to know whether someone has implemented something similar, or can think of a way to improve or change this. Any help is welcome.
Below is the code I am currently using. If anything is unclear, please ask!
from .transformer import Transformer
from ..base import VisibilityBis
from ..base.subms import SubMS
from ..base.dataset import Dataset
import xarray as xr
import numpy as np
import dask.array as da
import dask
import itertools
class Bispectrum(Transformer):
    def __init__(self, ant_ref = 7, **kwargs):
        self.ant_ref = ant_ref
        super().__init__(**kwargs)
    def bispectrum(self, filter_comb, vis, antenna1, antenna2, nchans, ncorrs):
        arr_i = da.array(filter_comb)[:,0]
        arr_j = da.array(filter_comb)[:,1]
        arr_k = da.array(filter_comb)[:,2]
        arr_j_aux = da.repeat(arr_j, ncorrs).reshape((len(arr_j), ncorrs))
        arr_j_temp = da.tile(arr_j_aux, nchans).reshape((len(arr_j_aux), nchans, ncorrs))
        arr_i_aux = da.repeat(arr_i, ncorrs).reshape((len(arr_i), ncorrs))
        arr_i_temp = da.tile(arr_i_aux, nchans).reshape((len(arr_i_aux), nchans, ncorrs))
        arr_k_aux = da.repeat(arr_k, ncorrs).reshape((len(arr_k), ncorrs))
        arr_k_temp = da.tile(arr_k_aux, nchans).reshape((len(arr_k_aux), nchans, ncorrs))
        ant1_aux = da.repeat(antenna1, ncorrs, axis=0).reshape((len(antenna1), ncorrs))
        antenna1_temp = da.tile(ant1_aux, nchans).reshape((len(ant1_aux), nchans, ncorrs))
        ant2_aux = da.repeat(antenna2, ncorrs, axis=0).reshape((len(antenna2), ncorrs))
        antenna2_temp = da.tile(ant2_aux, nchans).reshape((len(ant2_aux), nchans, ncorrs))
        ij = ((antenna1_temp[np.newaxis,:] == arr_i_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_j_temp[:,np.newaxis]))
        jk = ((antenna1_temp[np.newaxis,:] == arr_j_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_k_temp[:,np.newaxis]))
        ki = ((antenna1_temp[np.newaxis,:] == arr_k_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_i_temp[:,np.newaxis]))
        vis_model_ij = vis * ij
        vis_model_jk = vis * jk
        vis_model_ki = vis * ki
        vis_model_spec = vis_model_ij * vis_model_jk * vis_model_ki
        return vis_model_spec
    def transform(self) -> None:
        comb = itertools.combinations(self.input_data.antenna.dataset.ROWID.data.compute(),3)
        filter_comb = [i for i in comb if i[0] == self.ant_ref or i[1] == self.ant_ref or i[2] == self.ant_ref]
        ms_list = []
        for ms in self.input_data.ms_list:
            flags = ms.visibilities.flag.data
            spw_id = ms.spw_id
            pol_id = ms.polarization_id
            nchans = self.input_data.spws.nchans[spw_id]
            ncorrs = self.input_data.polarization.ncorrs[pol_id]
            vis = ms.visibilities.data * weight_broadcast * ~flags
            antenna1 = ms.visibilities.antenna1.data
            antenna2 = ms.visibilities.antenna2.data
            vis_temp = da.vstack((vis,da.conj(vis)))
            antenna1_temp = da.hstack((antenna1,antenna2))
            antenna2_temp = da.hstack((antenna2,antenna1))
            vis_bis = self.bispectrum(filter_comb, vis_temp,antenna1_temp, antenna2_temp, nchans, ncorrs)
            ds = xr.Dataset(
                data_vars = dict(
                    DATA=(['comb','nrow', 'nchans', 'ncorrs'], vis_bis),
                )
            )
            visibility_obj = VisibilityBis(dataset=ds)
            ms_obj = SubMS(
                visibilities=visibility_obj,
            )
            ms_list.append(ms_obj)
        return ms_list
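Both the original code and the replies build the triangle list the same way: every three-antenna combination that contains the reference antenna. A short plain-Python sketch (the helper names and the 10-antenna array are assumptions for illustration) that lists those triangles, the three baselines each one closes, and how many there are:

```python
import itertools
from math import comb

def triangles_with_ref(n_ant, ref_ant):
    """Hypothetical helper: antenna triples (i, j, k), i < j < k, that contain ref_ant."""
    return [c for c in itertools.combinations(range(n_ant), 3) if ref_ant in c]

def closing_baselines(tri):
    """The three baselines (i, j), (j, k), (k, i) that close triangle (i, j, k)."""
    i, j, k = tri
    return [(i, j), (j, k), (k, i)]

n_ant, ref_ant = 10, 7  # assumed array size; ref_ant = 7 as in the code above
tris = triangles_with_ref(n_ant, ref_ant)
# With one antenna fixed, the other two are chosen from the remaining n_ant - 1.
print(len(tris), comb(n_ant - 1, 2))        # 36 36
print(tris[0], closing_baselines(tris[0]))  # (0, 1, 7) [(0, 1), (1, 7), (7, 0)]
```

In a Measurement Set the (k, i) baseline is usually stored with the lower antenna index first, i.e. as (i, k), so its visibility has to be conjugated when forming the product; the original code handles this by stacking `vis` with `da.conj(vis)` and swapping the antenna columns.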
Hi @Hackasteroid142! I have put together the following example, although there was some guesswork involved as the code above was not complete. If I follow correctly, you want to compute the three-point correlation function. I believe the following code (quite different from your example) can do it, although it needs further testing. This is by no means an optimal implementation, but I would be interested to know whether it runs, whether it gives sensible results and whether its performance is acceptable.
from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit
def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""
    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    return _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)
@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.
    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """
    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size
    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex64)
    for row in range(n_row):
        ut = utime_inv[row]
        a1 = ant1[row]
        a2 = ant2[row]
        for ic, c in enumerate(comb):
            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate()
                else:
                    bis[ic, ut] *= data[row]
    return bis
def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.
    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """
    n_ant = ant_xds.dims["row"]
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = tuple([i for i in comb if ref_ant in i])
    bispectrum_xdsl = []
    for xds in data_xdsl:
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data
        vis = data * weight * ~flags  # Bad/flagged points will be zero.
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data
        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            vis, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"),
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex64
        )
        xds = xarray.Dataset(
            data_vars = dict(
                BISPECTRUM=(
                    ['triangle','ntime', 'nchans', 'ncorrs'], bispectrum
                )
            )
        )
        bispectrum_xdsl.append(xds)
    return bispectrum_xdsl
if __name__ == "__main__":
    ms = "~/reductions/3C147/msdir/C147_unflagged.MS"
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )
    ant_xds = xds_from_table(ms + "::ANTENNA")[0]
    ref_ant = 7
    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)
    results = da.compute(bispectrum_xdsl, scheduler="threads")
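For cross-checking the numba kernel on a handful of rows, a slow but simple reference can help. The sketch below is plain NumPy (no numba, no dask) and `bispectrum_reference` is a hypothetical name. It assumes each baseline appears once per timestamp with ANTENNA1 < ANTENNA2, and forms V_ij V_jk conj(V_ik) for each triangle (i, j, k), which is what the `(a1 == c[0]) and (a2 == c[-1])` branch of the kernel does. Unlike the kernel, it looks up only the three closing baselines, so autocorrelation rows are ignored rather than multiplied in.

```python
import numpy as np

def bispectrum_reference(data, time, ant1, ant2, triangles):
    """Per-time triangle product V_ij * V_jk * conj(V_ik) in plain NumPy.

    data: (n_row, n_chan, n_corr) complex array; time, ant1, ant2: (n_row,).
    Assumes each baseline is stored once per time with ant1 < ant2.
    Triangles with a missing baseline are set to 0.
    """
    utime, utime_inv = np.unique(time, return_inverse=True)
    n_row, n_chan, n_corr = data.shape
    bis = np.ones((len(triangles), utime.size, n_chan, n_corr), dtype=np.complex128)
    # Map (time index, antenna1, antenna2) -> row number.
    lookup = {
        (int(utime_inv[r]), int(ant1[r]), int(ant2[r])): r for r in range(n_row)
    }
    for ic, (i, j, k) in enumerate(triangles):
        for t in range(utime.size):
            rows = [lookup.get((t, i, j)), lookup.get((t, j, k)), lookup.get((t, i, k))]
            if None in rows:
                bis[ic, t] = 0  # a closing baseline is missing at this time
            else:
                r_ij, r_jk, r_ik = rows
                # (k, i) is stored as (i, k), so it is conjugated.
                bis[ic, t] = data[r_ij] * data[r_jk] * data[r_ik].conj()
    return bis

# Toy data: 3 antennas, 2 timestamps, random complex visibilities.
rng = np.random.default_rng(0)
pairs = [(0, 1), (0, 2), (1, 2)]
time = np.repeat([0.0, 1.0], len(pairs))
ant1 = np.array([p[0] for p in pairs] * 2)
ant2 = np.array([p[1] for p in pairs] * 2)
data = rng.normal(size=(6, 1, 2)) + 1j * rng.normal(size=(6, 1, 2))
bis = bispectrum_reference(data, time, ant1, ant2, [(0, 1, 2)])
print(bis.shape)  # (1, 2, 1, 2): (triangle, time, chan, corr)
```

Comparing this against `compute_bispectrum` on the same small arrays is one way to confirm that the triangles are being formed correctly before scaling up.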
Hi @JSKenyon, thanks for your reply! I tried the code you posted and it doesn't give me correct results. My test relies on a property of the bispectrum: when the three visibilities are multiplied together, the result is no longer affected by antenna-based phases. So the experiment is to run the code on the original visibilities and then again on the visibilities corrupted by random phases. The random phases are per antenna, so that visibilities associated with the same antenna are affected by the same phase. If everything is correct, the results should be the same in both cases; unfortunately this is not happening, as I get different results.
Here is the code, edited to add the phases to the visibilities.
from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit
def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""
    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    return _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)
@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.
    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """
    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size
    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex64)
    for row in range(n_row):
        ut = utime_inv[row]
        a1 = ant1[row]
        a2 = ant2[row]
        for ic, c in enumerate(comb):
            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate()
                else:
                    bis[ic, ut] *= data[row]
    return bis
def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.
    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """
    n_ant = ant_xds.dims["row"]
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if ref_ant in i])
    bispectrum_xdsl = []
    # One random phase per antenna (a 1-D dask array of length n_ant).
    phases = da.from_array(np.random.normal(0, 1, size=n_ant))
    for xds in data_xdsl:
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data
        vis = data * weight * ~flags  # Bad/flagged points will be zero.
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data
        diff = phases[antenna1] - phases[antenna2]
        # Add two trailing length-1 axes so the per-row exponential
        # broadcasts against the (row, chan, corr) visibilities.
        exponential = da.exp(2j*np.pi*diff)[:, None, None]
        data_exp = exponential * vis
        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            data_exp, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"),
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex64
        )
        xds = xarray.Dataset(
            data_vars = dict(
                BISPECTRUM=(
                    ['triangle','ntime', 'nchans', 'ncorrs'], bispectrum
                )
            )
        )
        bispectrum_xdsl.append(xds)
    return bispectrum_xdsl
if __name__ == "__main__":
    ms = '/home/hackasteroid/Documentos/tesis/tesis_pruebas/HD143006/server/HD143006_continuum.ms'
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )
    ant_xds = xds_from_table(ms + "::ANTENNA")[0]
    ref_ant = 7
    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)
    results = da.compute(bispectrum_xdsl, scheduler="threads")
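The test described above can be sketched on a single triangle. The visibilities below are random complex numbers standing in for an arbitrary source (an assumption; no real data is used), and each antenna gets one random phase, applied as exp(2πi(φ_p - φ_q)) on baseline (p, q) as in the code above. The product V_ij V_jk V_ki should be the same with and without the corruption.

```python
import numpy as np

rng = np.random.default_rng(7)

# Hypothetical true visibilities on the three baselines of triangle (0, 1, 2).
v01, v12, v02 = rng.normal(size=3) + 1j * rng.normal(size=3)

# One random phase per antenna, applied as exp(2*pi*i*(phi_p - phi_q)).
phi = rng.normal(0, 1, size=3)

def corrupt(v, p, q):
    return np.exp(2j * np.pi * (phi[p] - phi[q])) * v

def bispectrum(v_ij, v_jk, v_ik):
    # V_ki is the conjugate of the stored V_ik.
    return v_ij * v_jk * np.conj(v_ik)

clean = bispectrum(v01, v12, v02)
corrupted = bispectrum(corrupt(v01, 0, 1), corrupt(v12, 1, 2), corrupt(v02, 0, 2))
print(np.isclose(clean, corrupted))
```

Forgetting the conjugate on the (i, k) baseline leaves a residual factor exp(4πi(φ_i - φ_k)), so a check like this quickly catches a wrongly oriented baseline.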
Hi @Hackasteroid142! I have now made an effort to test that the code produces the correct result and I am fairly confident that it is now working (subject to the caveats listed in the code). Note that the phase of the quantity you are computing is the closure phase. The amended code demonstrates that the closure phase remains zero (or very close to it, due to finite precision) on all triangles when a random antenna-based phase is introduced. As an extra tip, when posting code to GitHub you can add a language after the opening triple backticks to enable syntax highlighting; I have edited your previous posts accordingly if you would like to see how.
from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit
def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""
    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    bis, cnt = _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)
    # NOTE: Removed from numba due to slightly weird results.
    bis[cnt != 3] = 0  # Zero bispectrum on bad triangles (missing data).
    return bis
@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.
    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """
    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size
    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex128)
    cnt = np.zeros((n_comb, n_utime, n_chan, n_corr), dtype=np.int8)
    for row in range(n_row):
        ut = utime_inv[row]
        a1 = ant1[row]
        a2 = ant2[row]
        for ic, c in enumerate(comb):
            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate()
                else:
                    bis[ic, ut] *= data[row]
                cnt[ic, ut] += 1
    return bis, cnt
def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.
    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """
    n_ant = ant_xds.dims["row"]
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = tuple([i for i in comb if ref_ant in i])
    bispectrum_xdsl = []
    for xds in data_xdsl:
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data
        vis = data * weight * ~flags  # Bad/flagged points will be zero.
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data
        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            vis, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"),
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex128
        )
        xds = xarray.Dataset(
            data_vars=dict(
                BISPECTRUM=(
                    ('triangle','ntime', 'nchans', 'ncorrs'), bispectrum
                )
            ),
            coords=dict(
                tri0=(('triangle',), np.array([i[0] for i in filter_comb])),
                tri1=(('triangle',), np.array([i[1] for i in filter_comb])),
                tri2=(('triangle',), np.array([i[2] for i in filter_comb]))
            )
        )
        bispectrum_xdsl.append(xds)
    return bispectrum_xdsl
def apply_per_antenna_phase(data_xdsl, n_ant):
    """Applies a random phase per-antenna.
    Currently lacks time and frequency variation but that is irrelevant for
    this test.
    """
    output_xdsl = []
    for xds in data_xdsl:
        data = xds.DATA.data
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data
        phases = da.exp(2*np.pi*1j*da.random.random(n_ant))
        a1_phase = phases[antenna1]
        a2_phase = phases[antenna2].conj()
        corrupted_data = \
            a1_phase[:, None, None] * data * a2_phase[:, None, None]
        output_xds = xds.assign({"DATA": (xds.DATA.dims, corrupted_data)})
        output_xdsl.append(output_xds)
    return output_xdsl
if __name__ == "__main__":
    ms = "~/reductions/3C147/msdir/C147_unflagged.MS"
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )
    # Replace data with ones for testing purposes. Bispectrum/closure phase
    # should be zero as this is equivalent to a point source at the field
    # centre (this is not true for the off-diagonal terms, but we don't care).
    data_xdsl = [
        xds.assign(
            {
                "DATA": (
                    xds.DATA.dims,
                    da.ones(
                        xds.DATA.data.shape,
                        dtype=np.complex128,
                        chunks=(-1,-1,-1)
                    )
                )
            }
        ) for xds in data_xdsl
    ]
    ant_xds = xds_from_table(ms + "::ANTENNA")[0]
    n_ant = ant_xds.dims["row"]
    ref_ant = 0
    # Create and apply a random phase per antenna. Lacks time and frequency
    # variation but that shouldn't matter for testing purposes.
    data_xdsl = apply_per_antenna_phase(data_xdsl, n_ant)
    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)
    results = da.compute(bispectrum_xdsl, scheduler="threads")
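To judge the outcome of this test, one option is to look at the phase of the bispectrum rather than its value. A minimal NumPy sketch with a hypothetical `closure_phase` helper; it assumes, as in `compute_bispectrum` above, that a bispectrum of exactly 0 marks a bad triangle, and returns NaN there so those entries are not mistaken for a zero closure phase:

```python
import numpy as np

def closure_phase(bis, deg=False):
    """Phase of the bispectrum, with NaN where the bispectrum is exactly 0.

    Assumption: a value of 0 marks a bad triangle (missing or flagged data).
    """
    phase = np.angle(bis, deg=deg)
    return np.where(bis == 0, np.nan, phase)

# Toy values: a good triangle, a bad one, a tiny phase, and a phase of pi.
toy = np.array([1 + 0j, 0j, np.exp(1e-9j), -1 + 0j])
cp = closure_phase(toy)
print(cp)
```

Assuming `results` comes from the `da.compute(...)` call above, `closure_phase(results[0][0].BISPECTRUM.values)` should give the closure phases for the first dataset, and `np.nanmax(np.abs(...))` is then a quick way to confirm they are close to zero. Closure phases wrap at ±π, so compare against a tolerance rather than testing for exact equality.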
Hi @JSKenyon, thanks again for the help and advice; I'm very new to this, so I'm still learning. I have a question about something you said in your last post: that for a point source the closure phases are zero. I don't understand why, because if each visibility is one, then the product of the three visibilities (the bispectrum) should be one, not zero, regardless of the random phases. However, when I run your test the bispectrum comes out as zero or nearly zero, and as I said, I don't understand why.
Here are the results I obtained for one combination of antennas, specifically (0,1,2):
array([[[1.55230006e-07-1.32348898e-23j, 2.36157651e-07+0.00000000e+00j]],
[[1.55374352e-07+0.00000000e+00j, 2.36367621e-07-1.32348898e-23j]],
[[1.55511986e-07-1.32348898e-23j, 2.36574375e-07+0.00000000e+00j]],
[[1.55653593e-07+0.00000000e+00j, 2.36796126e-07+0.00000000e+00j]],
[[1.55795519e-07+0.00000000e+00j, 2.37003686e-07+0.00000000e+00j]],
[[1.56199323e-07+0.00000000e+00j, 2.37597664e-07-1.32348898e-23j]],
[[1.56339270e-07+1.32348898e-23j, 2.37812000e-07-2.64697796e-23j]],
[[1.56482272e-07+0.00000000e+00j, 2.38022503e-07-1.32348898e-23j]],
[[1.56623752e-07+0.00000000e+00j, 2.38233461e-07+0.00000000e+00j]],
[[1.56768598e-07-1.32348898e-23j, 2.38446316e-07+0.00000000e+00j]],
...
[[1.69481860e-07-1.32348898e-23j, 2.59901916e-07-1.32348898e-23j]],
[[1.69416613e-07-1.32348898e-23j, 2.59840553e-07-1.32348898e-23j]],
[[1.69351884e-07-1.32348898e-23j, 2.59772824e-07-1.32348898e-23j]],
[[1.69287249e-07-1.32348898e-23j, 2.59715085e-07-2.64697796e-23j]],
[[1.69224460e-07-1.32348898e-23j, 2.59649774e-07+0.00000000e+00j]],
[[1.69043460e-07+0.00000000e+00j, 2.59481203e-07-2.64697796e-23j]],
[[1.68977989e-07+0.00000000e+00j, 2.59412998e-07+1.32348898e-23j]],
[[1.68913338e-07+0.00000000e+00j, 2.59349050e-07-1.32348898e-23j]],
[[1.68848525e-07+0.00000000e+00j, 2.59287877e-07+0.00000000e+00j]],
[[1.68784283e-07+0.00000000e+00j, 2.59229124e-07-2.64697796e-23j]]])
Hi @Hackasteroid142. Specifically, I said that the closure phase (the phase of the bispectrum) should be zero, not the bispectrum itself. The reason is that the Fourier transform of a point source at the field centre is real and constant, i.e. it has zero phase. This was just an easy way to check that I was forming the triangles correctly, as the phase would not be zero if they were incorrect.
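The argument above, and the cancellation the earlier tests rely on, can be written out explicitly. The notation is assumed rather than taken from the thread: $g_p$ is a complex per-antenna gain (the random phase in the tests is $g_p = e^{2\pi i \phi_p}$), $V_{pq}$ is the true visibility on baseline $(p, q)$, and $V_{qp} = V_{pq}^{*}$.

```latex
\begin{aligned}
\tilde V_{pq} &= g_p \, V_{pq} \, g_q^{*} \\
\tilde B_{ijk} &= \tilde V_{ij} \, \tilde V_{jk} \, \tilde V_{ki}
  = |g_i|^2 \, |g_j|^2 \, |g_k|^2 \, V_{ij} \, V_{jk} \, V_{ki} \\
\arg \tilde B_{ijk} &= \arg\left( V_{ij} \, V_{jk} \, V_{ki} \right)
\end{aligned}
```

For a point source of flux density $S > 0$ at the field centre, $V_{pq} = S$ on every baseline, so $\tilde B_{ijk} = |g_i|^2 |g_j|^2 |g_k|^2 S^3$ is real and positive and the closure phase is zero. With phase-only gains ($|g_p| = 1$) and $S = 1$, as in the test, $\tilde B_{ijk} = 1$ exactly.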
As to why you are seeing values close to zero you would need to check a few things:
- Visibility magnitudes: what are the magnitudes in the data?
- Weights: Are your weights non-zero/sensible?
- Flags: Are you looking at a triangle that includes a flagged value?
Hi @JSKenyon, it turns out the reason I was getting those results was the weights that multiply the visibilities. If I remove that operation, I get 1 for the bispectrum, so now I understand the previous result better. Thank you very much for all the help.
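This is consistent with the weights being real and non-negative: multiplying each visibility by its weight scales the bispectrum by the product of the three weights and leaves its phase, the closure phase, untouched. A small NumPy sketch with made-up weights (assumed values, not the weights from the data above):

```python
import numpy as np

rng = np.random.default_rng(1)
g = np.exp(2j * np.pi * rng.random(3))  # random phase-only gain per antenna

def corrupted_unit_vis(p, q):
    """Unit (point-source) visibility on baseline (p, q) after the antenna phases."""
    return g[p] * np.conj(g[q])

w01, w12, w02 = 2e-3, 5e-3, 4e-3  # made-up real, positive weights

unweighted = (corrupted_unit_vis(0, 1) * corrupted_unit_vis(1, 2)
              * np.conj(corrupted_unit_vis(0, 2)))
weighted = ((w01 * corrupted_unit_vis(0, 1)) * (w12 * corrupted_unit_vis(1, 2))
            * np.conj(w02 * corrupted_unit_vis(0, 2)))

print(abs(unweighted), abs(weighted))            # ~1 and ~w01 * w12 * w02 = 4e-8
print(np.angle(unweighted), np.angle(weighted))  # both ~0
```

So small weights alone can push the bispectrum magnitude towards zero while the closure phase stays at zero.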
Great! In that case I will go ahead and close this issue for now.
Many thanks for handling this @JSKenyon