Aggregation

Xarray-Beam can perform efficient distributed data aggregation in the “map-reduce” model.

This currently only includes Mean, but we would welcome contributions of other aggregation functions such as Sum, Std, Var, Min, Max, etc.

High-level API

The Mean transformation comes in three forms: Mean, Mean.Globally, and Mean.PerKey. The implementation is highly scalable because it is built on a Beam CombineFn.
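A Beam CombineFn is defined by four methods: create_accumulator, add_input, merge_accumulators and extract_output. The sketch below is hypothetical and is not Xarray-Beam's implementation. It shows why a mean fits this model: a (sum, count) pair is a small accumulator that each worker can fill independently, and partial accumulators merge by simple addition. Plain NumPy arrays stand in for xarray.Dataset objects.

```python
import numpy as np


class MeanCombineSketch:
    """Hypothetical mean combiner mirroring the CombineFn method names.

    The accumulator is a (running_sum, count) pair; it is not the
    accumulator Xarray-Beam actually uses.
    """

    def create_accumulator(self):
        # Empty accumulator: nothing summed yet.
        return (0.0, 0)

    def add_input(self, accumulator, array):
        # Fold one input into the running sum and count.
        total, count = accumulator
        return (total + np.asarray(array, dtype=np.float64), count + 1)

    def merge_accumulators(self, accumulators):
        # Partial results from different workers merge by adding sums and counts.
        total, count = 0.0, 0
        for t, c in accumulators:
            total = total + t
            count += c
        return (total, count)

    def extract_output(self, accumulator):
        # Only at the very end is the sum divided by the count.
        total, count = accumulator
        return total / count


fn = MeanCombineSketch()
inputs = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]

# Simulate two workers, each seeing part of the input.
worker_a = fn.add_input(fn.add_input(fn.create_accumulator(), inputs[0]), inputs[1])
worker_b = fn.add_input(fn.create_accumulator(), inputs[2])

result = fn.extract_output(fn.merge_accumulators([worker_a, worker_b]))
print(result)  # [3. 4.]
```

Because the merge is associative, the runner is free to decide how inputs are split across workers and in what order partial results are combined.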

The high-level Mean transform can be used to aggregate a distributed dataset across an existing dimension or dimensions, similar to Xarray’s .mean() method:

import apache_beam as beam
import numpy as np
import xarray_beam as xbeam
import xarray

ds = xarray.tutorial.load_dataset('air_temperature')
print(ds)
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
with beam.Pipeline() as p:
    p | xbeam.DatasetToChunks(ds, chunks={'time': 1000}) | xbeam.Mean('time') | beam.Map(print)
(Key(offsets={'lat': 0, 'lon': 0}, vars=None), <xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float64 260.4 260.2 259.9 259.5 ... 297.3 297.3 297.3)

Lower-level API

Xarray-Beam also includes lower-level transformations modeled on beam.Mean rather than xarray.Dataset.mean(): they compute averages over sequences of xarray.Dataset objects or (key, xarray.Dataset) pairs, rather than over an existing Xarray dimension or based on xarray_beam.Key objects, e.g.,

datasets = [
    xarray.Dataset({'foo': ('x', np.random.randn(3))})
    for _ in range(100)
]
datasets | xbeam.Mean.Globally()
[<xarray.Dataset>
 Dimensions:  (x: 3)
 Dimensions without coordinates: x
 Data variables:
     foo      (x) float64 0.003679 0.3131 0.03719]

Notice that the existing dimensions of each dataset are unchanged by the transformation. If you want to average over existing dimensions, use the high-level Mean transform or do that aggregation yourself, e.g., by averaging inside each chunk before combining the data.
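If you average inside each chunk yourself, the per-chunk results have to be weighted by chunk size before they are combined. Otherwise, chunks of unequal length along the reduced dimension bias the result. Here is a minimal NumPy sketch using hypothetical data, with plain arrays standing in for xarray.Dataset chunks. It keeps a (sum, count) pair per chunk instead of a bare mean:

```python
import numpy as np

# Hypothetical data: 10 time steps x 2 grid points, split into unequal time chunks.
data = np.arange(20, dtype=np.float64).reshape(10, 2)
chunks = [data[:3], data[3:]]  # 3 and 7 time steps

# Reduce each chunk over its own time axis first, keeping the count alongside the sum.
partials = [(chunk.sum(axis=0), chunk.shape[0]) for chunk in chunks]

# Combine the partial sums, so each chunk is weighted by its number of time steps.
total = sum(s for s, _ in partials)
count = sum(n for _, n in partials)
combined_mean = total / count

# Averaging the per-chunk means directly ignores the unequal chunk sizes.
naive_mean = np.mean([chunk.mean(axis=0) for chunk in chunks], axis=0)

print(combined_mean)  # [ 9. 10.]
print(naive_mean)     # [ 7.  8.]
```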

Similarly, the keys fed into xbeam.Mean.PerKey can be any hashable values, including but not limited to xbeam.Key:

datasets = [
    (time.dt.season.item(), ds.sel(time=time).mean())
    for time in ds.time
]
datasets | xbeam.Mean.PerKey()
[('DJF',
  <xarray.Dataset>
  Dimensions:  ()
  Data variables:
      air      float64 273.6),
 ('MAM',
  <xarray.Dataset>
  Dimensions:  ()
  Data variables:
      air      float64 279.0),
 ('JJA',
  <xarray.Dataset>
  Dimensions:  ()
  Data variables:
      air      float64 289.2),
 ('SON',
  <xarray.Dataset>
  Dimensions:  ()
  Data variables:
      air      float64 283.0)]

Mean.PerKey is particularly useful in combination with beam.GroupByKey for performing large-scale “group by” operations. For an example, take a look at the ERA5 climatology example.
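Conceptually, a per-key mean keeps one running accumulator for each key, so the full set of values for a key never needs to be held at once. The plain-Python sketch below illustrates this idea with hypothetical scalar values and season keys. It is not the library's code, and in a real pipeline Beam decides where these accumulators live and how they are merged.

```python
from collections import defaultdict

# Hypothetical (key, value) pairs; as with Mean.PerKey, keys may be any hashable.
pairs = [
    ('DJF', 270.0),
    ('JJA', 290.0),
    ('DJF', 274.0),
    ('JJA', 288.0),
    ('MAM', 279.0),
]

# One [sum, count] accumulator per key, updated as each element arrives.
accumulators = defaultdict(lambda: [0.0, 0])
for key, value in pairs:
    acc = accumulators[key]
    acc[0] += value
    acc[1] += 1

# Divide once per key at the end.
means = {key: total / count for key, (total, count) in accumulators.items()}
print(means)  # {'DJF': 272.0, 'JJA': 289.0, 'MAM': 279.0}
```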

Custom aggregations

The “tree reduction” algorithm used by the combiner inside Mean is great, but it isn’t the only way to aggregate a dataset with Xarray-Beam.
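To illustrate the idea of a tree reduction, the hypothetical sketch below merges (sum, count) accumulators in pairs, level by level. The number of merge rounds therefore grows only logarithmically with the number of chunks. The pairing scheme and the helper names tree_reduce and merge_sum_count are assumptions made for illustration. The actual tree shape is determined by Beam and the runner, not by this code.

```python
import numpy as np


def tree_reduce(accumulators, merge):
    """Hypothetical helper: merge accumulators pairwise until one remains."""
    level = list(accumulators)
    while len(level) > 1:
        next_level = []
        for i in range(0, len(level), 2):
            if i + 1 < len(level):
                # Merge neighbouring accumulators.
                next_level.append(merge(level[i], level[i + 1]))
            else:
                # An odd leftover is carried up to the next level unchanged.
                next_level.append(level[i])
        level = next_level
    return level[0]


def merge_sum_count(a, b):
    """Hypothetical merge for (sum, count) accumulators."""
    return (a[0] + b[0], a[1] + b[1])


rng = np.random.default_rng(0)
# Five chunks of different lengths along the first axis.
chunks = [rng.normal(size=(n, 3)) for n in (4, 5, 6, 7, 8)]

# Leaves of the tree: one accumulator per chunk.
leaves = [(c.sum(axis=0), c.shape[0]) for c in chunks]
total, count = tree_reduce(leaves, merge_sum_count)
tree_mean = total / count
```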

In many cases, the easiest way to scale up an aggregation pipeline is to use rechunking to convert the many small datasets inside your pipeline into a form that is easier to process in a scalable way. However, rechunking is much less efficient than using a combiner, because each use of Rechunk requires a complete shuffle of the input data (i.e., writing all data in the pipeline to temporary files on disk).
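The reason rechunking amounts to a full shuffle is that each output chunk needs a piece of every input chunk. The in-memory NumPy sketch below illustrates this with a small hypothetical (time, lat, lon) array. It converts time-chunked blocks into lat/lon blocks that span the full time axis. It only shows which data moves where and does not reflect how xbeam.Rechunk schedules the shuffle or bounds memory.

```python
import numpy as np

# Hypothetical array with axes (time, lat, lon), small enough to inspect.
data = np.arange(6 * 4 * 4).reshape(6, 4, 4)

# Source layout: chunked along time (2 steps each), full lat/lon in every chunk.
source = {t: data[t:t + 2] for t in range(0, 6, 2)}

# Target layout: full time axis in every chunk, split into 2x2 lat/lon tiles.
target = {}
for lat in range(0, 4, 2):
    for lon in range(0, 4, 2):
        # Every target tile gathers a slice from every source chunk.
        pieces = [source[t][:, lat:lat + 2, lon:lon + 2] for t in sorted(source)]
        target[(lat, lon)] = np.concatenate(pieces, axis=0)
```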

For example, here’s how one could compute the median, which is a notoriously difficult statistic to calculate with distributed algorithms:

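# Read the data in chunks of 100 time steps, each covering the full lat/lon grid...
source_chunks = {'time': 100, 'lat': -1, 'lon': -1}
# ...then rechunk so that each chunk holds the full time axis for a 10x10 lat/lon tile.
working_chunks = {'lat': 10, 'lon': 10, 'time': -1}

with beam.Pipeline() as p:
    (
        p
        | xbeam.DatasetToChunks(ds, source_chunks)
        # Shuffle into the working layout (itemsize=4 bytes, since 'air' is float32).
        | xbeam.Rechunk(ds.sizes, source_chunks, working_chunks, itemsize=4)
        # Each chunk now contains every time step, so the median is exact; drop the
        # 'time' offset from the key because that dimension no longer exists.
        | beam.MapTuple(lambda k, v: (k.with_offsets(time=None), v.median('time')))
        # Merge the lat/lon tiles back into a single dataset.
        | xbeam.ConsolidateChunks({'lat': -1, 'lon': -1})
        | beam.MapTuple(lambda k, v: print(v))
    )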
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float32 261.3 261.1 260.9 260.3 ... 297.3 297.3 297.3
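The pipeline rechunks to 'time': -1 before calling median because, unlike the mean, the median has no small accumulator that can be merged across chunks. Taking the median of per-chunk medians generally gives the wrong answer. The tiny NumPy illustration below uses hypothetical values for a single grid point:

```python
import numpy as np

# Hypothetical values along time for one grid point, split into two time chunks.
chunk_a = np.array([1.0, 2.0, 3.0])
chunk_b = np.array([4.0, 5.0, 100.0, 101.0, 102.0, 103.0])

# Combining per-chunk medians does not recover the true median...
median_of_medians = np.median([np.median(chunk_a), np.median(chunk_b)])

# ...whereas a single chunk holding the full time axis (as after the rechunk) does.
true_median = np.median(np.concatenate([chunk_a, chunk_b]))

print(median_of_medians, true_median)  # 51.25 5.0
```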