ratt-ru / dask-ms

Implementation of a dask/xarray dataset backed by a CASA MS

Home Page: https://dask-ms.readthedocs.io

Calculating Bispectrum with dask-ms and without loops

Hackasteroid142 opened this issue

  • dask-ms version: 0.2.15
  • Python version: 3.9.16
  • Operating System: Ubuntu Linux 22.04

Hi everyone, I am trying to implement the bispectrum method with dask-ms without any for loops. I currently have an implementation, but when I call compute() on the result my computer freezes.
The implementation consists of enumerating the antenna triples (i, j, k), finding the visibilities whose two antennas belong to the triple ((i, j), (j, k) or (k, i)), and multiplying the visibilities obtained for those pairs to form the bispectrum.
The idea is that the boolean array marking which visibilities satisfy the antenna-pair condition has the same shape as the visibility array, so that multiplying the two zeroes out the values that are not needed.
Building the graph does not take long but, as I said, the moment I call compute() to see the result my computer freezes. So far this is the only loop-free approach I have found; earlier attempts with explicit for loops take far too long to finish. I would therefore like to know whether someone has implemented something similar, or can think of a way to improve or change this. Any help is welcome.
Below is the code I am currently using. If anything is unclear, please ask!

from .transformer import Transformer
from ..base import VisibilityBis
from ..base.subms import SubMS
from ..base.dataset import Dataset
import xarray as xr
import numpy as np
import dask.array as da
import dask 
import itertools

class Bispectrum(Transformer):
    def __init__(self, ant_ref=7, **kwargs):
        self.ant_ref = ant_ref
        super().__init__(**kwargs)

    
    def bispectrum(self, filter_comb, vis, antenna1, antenna2, nchans, ncorrs):
        arr_i = da.array(filter_comb)[:,0]
        arr_j = da.array(filter_comb)[:,1]
        arr_k = da.array(filter_comb)[:,2]

        arr_j_aux = da.repeat(arr_j,  ncorrs).reshape((len(arr_j), ncorrs))
        arr_j_temp = da.tile(arr_j_aux, nchans).reshape((len(arr_j_aux), nchans, ncorrs))
        arr_i_aux = da.repeat(arr_i,  ncorrs).reshape((len(arr_i), ncorrs))
        arr_i_temp = da.tile(arr_i_aux, nchans).reshape((len(arr_i_aux), nchans, ncorrs))
        arr_k_aux = da.repeat(arr_k,  ncorrs).reshape((len(arr_k), ncorrs))
        arr_k_temp = da.tile(arr_k_aux, nchans).reshape((len(arr_k_aux), nchans, ncorrs))

        ant1_aux = da.repeat(antenna1, ncorrs, axis=0).reshape((len(antenna1), ncorrs))
        antenna1_temp = da.tile(ant1_aux, nchans).reshape((len(ant1_aux), nchans, ncorrs))
        ant2_aux = da.repeat(antenna2, ncorrs, axis=0).reshape((len(antenna2), ncorrs))
        antenna2_temp = da.tile(ant2_aux, nchans).reshape((len(ant2_aux), nchans, ncorrs))
        

        ij = ((antenna1_temp[np.newaxis,:] == arr_i_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_j_temp[:,np.newaxis]))
        jk = ((antenna1_temp[np.newaxis,:] == arr_j_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_k_temp[:,np.newaxis]))
        ki = ((antenna1_temp[np.newaxis,:] == arr_k_temp[:,np.newaxis]) & (antenna2_temp[np.newaxis,:] == arr_i_temp[:,np.newaxis]))


        vis_model_ij = vis * ij
        vis_model_jk = vis * jk
        vis_model_ki = vis * ki

        vis_model_spec = vis_model_ij * vis_model_jk * vis_model_ki

        return vis_model_spec


    def transform(self) -> None:
        comb = itertools.combinations(self.input_data.antenna.dataset.ROWID.data.compute(),3)
        filter_comb = [i for i in comb if i[0] == self.ant_ref or i[1] == self.ant_ref or i[2] == self.ant_ref]
        ms_list = []
        for ms in self.input_data.ms_list:
            flags = ms.visibilities.flag.data
            spw_id = ms.spw_id
            pol_id = ms.polarization_id

            nchans = self.input_data.spws.nchans[spw_id]
            ncorrs = self.input_data.polarization.ncorrs[pol_id]

            # NOTE: weight_broadcast is defined elsewhere in the original
            # class and is not shown in this snippet.
            vis = ms.visibilities.data * weight_broadcast * ~flags

            antenna1 = ms.visibilities.antenna1.data
            antenna2 = ms.visibilities.antenna2.data

            vis_temp = da.vstack((vis,da.conj(vis)))
            antenna1_temp = da.hstack((antenna1,antenna2))
            antenna2_temp = da.hstack((antenna2,antenna1))

            vis_bis = self.bispectrum(filter_comb, vis_temp,antenna1_temp, antenna2_temp, nchans, ncorrs)

            ds = xr.Dataset(
                data_vars = dict(
                    DATA=(['comb','nrow', 'nchans', 'ncorrs'], vis_bis),
                )
            )

            visibility_obj = VisibilityBis(dataset=ds)
            ms_obj = SubMS(
                visibilities=visibility_obj,
            )
            ms_list.append(ms_obj)

        return ms_list
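(For scale: the broadcast comparisons in bispectrum above materialise the boolean masks ij, jk and ki, each of shape (len(filter_comb), 2 * n_row, nchans, ncorrs), plus complex products of the same shape. With purely illustrative numbers, say a couple of thousand triangles, 10^5 rows, 64 channels and 4 correlations, each of those intermediates is already of order 10^11 elements, which is more than enough to exhaust memory and would explain the freeze on compute().)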

Hi @Hackasteroid142! I have put together the following example, although there was some guesswork involved as the code above was not complete. If I follow correctly, you want to compute the three-point correlation function. I believe the following code (quite different from your example) can do it, although it needs further testing. This is by no means the most optimal implementation, but I would be interested to know if it runs/gives decent results and whether its performance is acceptable.

from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit


def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""

    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    
    return _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)


@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.

    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """

    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size

    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex64)

    for row in range(n_row):
        
        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):

            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate() 
                else:
                    bis[ic, ut] *= data[row]

    return bis


def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.

    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """

    n_ant = ant_xds.dims["row"]

    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = tuple([i for i in comb if ref_ant in i])

    bispectrum_xdsl = []

    for xds in data_xdsl:
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data

        vis = data * weight * ~flags  # Bad/flagged points will be zero.

        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data

        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            vis, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"), 
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex64
        )

        xds = xarray.Dataset(
            data_vars = dict(
                BISPECTRUM=(
                    ['triangle','ntime', 'nchans', 'ncorrs'], bispectrum
                )
            )
        )

        bispectrum_xdsl.append(xds)

    return bispectrum_xdsl


if __name__ == "__main__":

    ms = "~/reductions/3C147/msdir/C147_unflagged.MS"
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )

    ant_xds = xds_from_table(ms + "::ANTENNA")[0]

    ref_ant = 7

    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)

    results = da.compute(bispectrum_xdsl, scheduler="threads")
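A side note on the unknown-chunks comment in transform: after the da.blockwise call the "t" dimension has chunks of np.nan. If a later step needs concrete chunk sizes, dask can resolve them at the cost of evaluating the graph once, e.g. directly after the blockwise call (sketch only):

        # Optional: resolve the unknown "t" chunk sizes. This triggers a
        # computation, so only do it if downstream code really needs it.
        bispectrum = bispectrum.compute_chunk_sizes()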

Hi @JSKenyon, thanks for your reply! I tried the code you posted but it does not give me good results. My test relies on a property of the bispectrum: when the three visibilities of a triangle are multiplied together, the antenna-based phases cancel. The experiment is therefore to run the code on the original visibilities and then again on visibilities corrupted by random phases. These random phases are per antenna, so that all visibilities involving the same antenna are affected by the same phase. If everything is correct, the resulting bispectra should be identical in both cases; unfortunately that is not happening, as the two runs give different results.
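Just to illustrate the property being tested, here is a minimal standalone numpy sketch (independent of the measurement set code; the values are arbitrary) showing that per-antenna phases cancel exactly around a closed triangle:

import numpy as np

rng = np.random.default_rng(0)

vis = rng.normal(size=3) + 1j * rng.normal(size=3)  # V_ij, V_jk, V_ki for one triangle
g = np.exp(2j * np.pi * rng.random(3))              # unit-amplitude phase term per antenna i, j, k

# Corrupt each baseline with its antenna phases: V'_pq = g_p * conj(g_q) * V_pq.
vis_corrupt = np.array([
    g[0] * np.conj(g[1]) * vis[0],  # (i, j)
    g[1] * np.conj(g[2]) * vis[1],  # (j, k)
    g[2] * np.conj(g[0]) * vis[2],  # (k, i)
])

# The phases cancel around the closed triangle, so the two bispectra agree.
print(np.allclose(np.prod(vis), np.prod(vis_corrupt)))  # True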
Here I leave the edited code to add the phase to the visibilities.

from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit


def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""

    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    
    return _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)


@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.

    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """

    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size

    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex64)

    for row in range(n_row):
        
        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):

            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate() 
                else:
                    bis[ic, ut] *= data[row]

    return bis


def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.

    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """

    n_ant = ant_xds.dims["row"]

    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if ref_ant in i])

    bispectrum_xdsl = []
    # One random phase per antenna, constant over frequency and correlation.
    phases = np.random.normal(0, 1, size=n_ant)

    for xds in data_xdsl:
        
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data
        n_row, n_chan, n_corr = data.shape

        vis = data * weight * ~flags  # Bad/flagged points will be zero.

        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data

        diff = phases[antenna1] - phases[antenna2]

        # I don't know if this is the best way to do it, but the objective is to
        # reshape the exponential to have the same shape as the visibilities.
        exponential = da.exp(2j * np.pi * diff).repeat(n_corr).reshape(
            (data.shape[0], n_corr)).repeat(n_chan).reshape(data.shape)

        data_exp = exponential * vis
        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            data_exp, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"), 
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex64
        )

        xds = xarray.Dataset(
            data_vars = dict(
                BISPECTRUM=(
                    ['triangle','ntime', 'nchans', 'ncorrs'], bispectrum
                )
            )
        )

        bispectrum_xdsl.append(xds)

    return bispectrum_xdsl



if __name__ == "__main__":

    ms = '/home/hackasteroid/Documentos/tesis/tesis_pruebas/HD143006/server/HD143006_continuum.ms'
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )

    ant_xds = xds_from_table(ms + "::ANTENNA")[0]

    ref_ant = 7

    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)

    results = da.compute(bispectrum_xdsl, scheduler="threads")

Hi @Hackasteroid142! I have now made an effort to test that the code produces the correct result and I am fairly confident that it is now working (subject to the caveats listed in the code). Note that the quantity you are computing is equivalent to the closure phase. The amended code demonstrates that the closure phase remains zero (or very close to it, due to finite precision) on all triangles when a random antenna-based phase is introduced. As an extra tip, when posting code to GitHub you can add a language after the opening triple backticks to enable syntax highlighting - I have edited your previous posts accordingly if you would like to see how.

from daskms import xds_from_ms, xds_from_table
import itertools
import xarray
import numpy as np
import dask.array as da
from numba import njit


def compute_bispectrum(data, time, ant1, ant2, comb):
    """Python wrapper function."""

    # Needs to be done outside numba due to implementation limitations.
    utime, utime_inv = np.unique(time, return_inverse=True)
    
    bis, cnt = _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2)

    # NOTE: Removed from numba due to slightly weird results.
    bis[cnt != 3] = 0  # Zero bispectrum on bad triangles (missing data).

    return bis


@njit(cache=True, nogil=True)
def _compute_bispectrum(data, comb, utime, utime_inv, ant1, ant2):
    """This function is compiled using numba, which makes loops fast.

    Note that this is not an optimal implementation - that requires some
    additional work. If you do decide to use numba, not all python constructs
    will work.
    """

    n_comb = len(comb)
    n_row, n_chan, n_corr = data.shape
    n_utime = utime.size

    bis = np.ones((n_comb, n_utime, n_chan, n_corr), dtype=np.complex128)
    cnt = np.zeros((n_comb, n_utime, n_chan, n_corr), dtype=np.int8)

    for row in range(n_row):
        
        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):

            if (a1 in c) and (a2 in c):
                if (a1 == c[0]) and (a2 == c[-1]):
                    bis[ic, ut] *= data[row].conjugate() 
                else:
                    bis[ic, ut] *= data[row]
                cnt[ic, ut] += 1

    return bis, cnt


def transform(data_xdsl, ant_xds, ref_ant):
    """Compute the bispectrum per time, per dataset.

    Note that this implementation is brittle - it requires that the data be
    ordered by TIME, ANTENNA1 and ANTENNA2. It doesn't know how to handle
    autocorrelations so make sure they are flagged.
    """

    n_ant = ant_xds.dims["row"]

    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = tuple([i for i in comb if ref_ant in i])

    bispectrum_xdsl = []

    for xds in data_xdsl:
        data = xds.DATA.data
        flags = xds.FLAG.data
        weight = xds.WEIGHT_SPECTRUM.data
        time = xds.TIME.data

        vis = data * weight * ~flags  # Bad/flagged points will be zero.

        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data

        # This is advanced dask usage. Here we are mapping a function over all
        # the blocks in the input. Note that we do not know how many unique
        # times there are per chunk and so we cannot determine ntime in
        # advance. This can be improved.
        bispectrum = da.blockwise(
            compute_bispectrum, ("triangle", "t", "f", "c"),
            vis, ("t", "f", "c"),
            time, ("t"),
            antenna1, ("t"), 
            antenna2, ("t"),
            filter_comb, None,
            align_arrays=False,
            adjust_chunks={"t": (np.nan,)*data.numblocks[0]},
            new_axes={"triangle": len(filter_comb)},
            dtype=np.complex128
        )

        xds = xarray.Dataset(
            data_vars=dict(
                BISPECTRUM=(
                    ('triangle','ntime', 'nchans', 'ncorrs'), bispectrum
                )
            ),
            coords=dict(
                tri0=(('triangle',), np.array([i[0] for i in filter_comb])),
                tri1=(('triangle',), np.array([i[1] for i in filter_comb])),
                tri2=(('triangle',), np.array([i[2] for i in filter_comb]))
            )
        )

        bispectrum_xdsl.append(xds)

    return bispectrum_xdsl


def apply_per_antenna_phase(data_xdsl, n_ant):
    """Applies a random phase per-antenna.
    
    Currently lacks time and frequency variation but that is irrelevant for
    this test.    
    """

    output_xdsl = []

    for xds in data_xdsl:

        data = xds.DATA.data
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data

        phases = da.exp(2*np.pi*1j*da.random.random(n_ant))

        a1_phase = phases[antenna1]
        a2_phase = phases[antenna2].conj()

        corrupted_data = \
            a1_phase[:, None, None] * data * a2_phase[:, None, None]

        output_xds = xds.assign({"DATA": (xds.DATA.dims, corrupted_data)})

        output_xdsl.append(output_xds)

    return output_xdsl


if __name__ == "__main__":

    ms = "~/reductions/3C147/msdir/C147_unflagged.MS"
    data_xdsl = xds_from_ms(
        ms,
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        columns=(
            "TIME", "ANTENNA1", "ANTENNA2", "DATA", "WEIGHT_SPECTRUM", "FLAG"
        ),
        chunks={'row': -1}
    )

    # Replace data with ones for testing purposes. Bispectrum/closure phase
    # should be zero as this is equivalent to a point source at the field 
    # centre (this is not true for the off-diagonal terms, but we don't care).
    data_xdsl = [
        xds.assign(
            {
                "DATA": (
                    xds.DATA.dims,
                    da.ones(
                        xds.DATA.data.shape,
                        dtype=np.complex128,
                        chunks=(-1,-1,-1)
                    )
                )
            }
        ) for xds in data_xdsl
    ]    

    ant_xds = xds_from_table(ms + "::ANTENNA")[0]

    n_ant = ant_xds.dims["row"]
    ref_ant = 0

    # Create and apply a random phase per antenna. Lacks time and frequency
    # variation but that shouldn't matter for testing purposes.
    data_xdsl = apply_per_antenna_phase(data_xdsl, n_ant)

    bispectrum_xdsl = transform(data_xdsl, ant_xds, ref_ant)

    results = da.compute(bispectrum_xdsl, scheduler="threads")
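As a quick check on that test (assuming the script above has been run), the closure phase, i.e. the angle of the bispectrum, should be at the level of numerical noise on every triangle:

for xds in results[0]:
    # results[0] holds the computed datasets; the angle should be ~0.
    print(np.abs(np.angle(xds.BISPECTRUM.values)).max())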

Hi @JSKenyon, thanks again for the help and advice; I'm very new to this so I'm still learning. I have a question about something you said in your last post, namely that for a point source the closure phases are zero. I don't really understand why: if each visibility is one, then the product of the three visibilities (the bispectrum) should be one, not zero, regardless of the random phases. However, I ran your test and the results for the bispectrum are zero or near zero, and I don't understand why.
Below are the results I obtained for one antenna combination, specifically (0, 1, 2).

array([[[1.55230006e-07-1.32348898e-23j, 2.36157651e-07+0.00000000e+00j]],

       [[1.55374352e-07+0.00000000e+00j, 2.36367621e-07-1.32348898e-23j]],

       [[1.55511986e-07-1.32348898e-23j, 2.36574375e-07+0.00000000e+00j]],

       [[1.55653593e-07+0.00000000e+00j, 2.36796126e-07+0.00000000e+00j]],

       [[1.55795519e-07+0.00000000e+00j, 2.37003686e-07+0.00000000e+00j]],

       [[1.56199323e-07+0.00000000e+00j, 2.37597664e-07-1.32348898e-23j]],

       [[1.56339270e-07+1.32348898e-23j, 2.37812000e-07-2.64697796e-23j]],

       [[1.56482272e-07+0.00000000e+00j, 2.38022503e-07-1.32348898e-23j]],

       [[1.56623752e-07+0.00000000e+00j, 2.38233461e-07+0.00000000e+00j]],

       [[1.56768598e-07-1.32348898e-23j, 2.38446316e-07+0.00000000e+00j]],

...

       [[1.69481860e-07-1.32348898e-23j, 2.59901916e-07-1.32348898e-23j]],

       [[1.69416613e-07-1.32348898e-23j, 2.59840553e-07-1.32348898e-23j]],

       [[1.69351884e-07-1.32348898e-23j, 2.59772824e-07-1.32348898e-23j]],

       [[1.69287249e-07-1.32348898e-23j, 2.59715085e-07-2.64697796e-23j]],

       [[1.69224460e-07-1.32348898e-23j, 2.59649774e-07+0.00000000e+00j]],

       [[1.69043460e-07+0.00000000e+00j, 2.59481203e-07-2.64697796e-23j]],

       [[1.68977989e-07+0.00000000e+00j, 2.59412998e-07+1.32348898e-23j]],

       [[1.68913338e-07+0.00000000e+00j, 2.59349050e-07-1.32348898e-23j]],

       [[1.68848525e-07+0.00000000e+00j, 2.59287877e-07+0.00000000e+00j]],

       [[1.68784283e-07+0.00000000e+00j, 2.59229124e-07-2.64697796e-23j]]])

Hi @Hackasteroid142. Specifically, I said that the closure phase (the phase of the bispectrum) should be zero, not the bispectrum itself. The reason is that the Fourier transform of a point source at the field centre is real and constant, i.e. it has zero phase. This was just an easy way to check that I was forming the triangles correctly, as the phase would not be zero if they were incorrect.
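To spell that out (using the standard expression for the visibility of a point source; sign conventions vary between texts): a point source of flux S at direction cosines (l, m) gives V(u, v) = S * exp(-2πi(ul + vm)). At the field centre l = m = 0, so every visibility is just the real, positive number S, and the closure phase arg(V_ij * V_jk * conj(V_ik)) is zero on any triangle, whatever S is.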

As to why you are seeing values close to zero you would need to check a few things:

  • Visibility magnitudes: what are the magnitudes in the data?
  • Weights: Are your weights non-zero/sensible?
  • Flags: Are you looking at a triangle that includes a flagged value?

Hi @JSKenyon, it turns out the reason I was getting those results was the weights that multiply the visibilities. If I remove that operation, the bispectrum comes out as 1, so now I understand the previous result better. Really, thank you very much for all the help.
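(As a rough sanity check of that explanation: each of the three visibilities in a triangle is scaled by its WEIGHT_SPECTRUM value, so the bispectrum amplitude picks up the product of three weights. Hypothetical weights of order 5e-3 per baseline would give amplitudes of order 10^-7, consistent with the array posted earlier.)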

Great! In that case I will go ahead and close this issue for now.

Many thanks for handling this @JSKenyon