API Reference

The API of jz-tree. Only functions that are relevant for interfacing with and calling the code are listed. Developers should consult the source code directly.

jztree.config

This module contains config dataclasses that bundle arguments modifying low-level details of the code. Relevant configs are always passed through function arguments and trigger jit-recompilation if modified. The main reasons to pass non-default configurations are that some allocation turned out too small, or to optimize memory usage.
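The recompilation behaviour can be pictured with a plain-Python sketch. `DemoTreeConfig` is a hypothetical stand-in (only its field names are borrowed from TreeConfig), `lru_cache` plays the role of the jit cache, and the allocation rule inside the kernel is made up for illustration.

```python
from dataclasses import dataclass, replace
from functools import lru_cache

# Hypothetical stand-in for a jz-tree config; not the library's class.
@dataclass(frozen=True)
class DemoTreeConfig:
    max_leaf_size: int = 32
    alloc_fac_nodes: float = 1.0

compile_count = 0

@lru_cache(maxsize=None)  # stands in for the jit compilation cache
def compiled_kernel(cfg: DemoTreeConfig):
    global compile_count
    compile_count += 1
    # Made-up allocation rule, only here to make the config matter.
    return lambda npart: int(cfg.alloc_fac_nodes * npart / cfg.max_leaf_size)

default = DemoTreeConfig()
compiled_kernel(default)(1024)
compiled_kernel(DemoTreeConfig())(2048)  # equal config: cache hit
bigger = replace(default, alloc_fac_nodes=1.5)
compiled_kernel(bigger)(1024)            # modified config: new "compilation"
```

Because the config is frozen and hashable, two equal configs share one cache entry, while any modified field produces a new one.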

class jztree.config.RegularizationConfig(regularize_percentile=90.0, max_volume_fac=20.0)

Config controlling regularization.

Parameters:
  • regularize_percentile (float)

  • max_volume_fac (float)

class jztree.config.TreeConfig(max_leaf_size=32, coarse_fac=6.0, stop_coarsen=1024, regularization=None, alloc_fac_nodes=1.0, nsamp=1024, mass_centered=False)

Config dataclass controlling tree structure and allocation.

Parameters:
  • max_leaf_size (int)

  • coarse_fac (float)

  • stop_coarsen (int)

  • regularization (RegularizationConfig | None)

  • alloc_fac_nodes (float)

  • nsamp (int)

  • mass_centered (bool)

class jztree.config.FofCatalogueConfig(npart_min=20)

Config dataclass controlling aspects of friends-of-friends catalogues

Parameters:
  • npart_min (int)

class jztree.config.FofConfig(alloc_fac_ilist=32.0, alloc_fac_distr_links=0.01, tree=TreeConfig(max_leaf_size=48, coarse_fac=8.0, stop_coarsen=2048, regularization=None, alloc_fac_nodes=1.1, nsamp=1024, mass_centered=False), catalogue=FofCatalogueConfig(npart_min=20))

Main nested config for friends-of-friends

Parameters:
  • alloc_fac_ilist (float)

  • alloc_fac_distr_links (float)

  • tree (TreeConfig)

  • catalogue (FofCatalogueConfig)

class jztree.config.KNNConfig(alloc_fac_ilist=256.0, tree=TreeConfig(max_leaf_size=48, coarse_fac=8.0, stop_coarsen=2048, regularization=RegularizationConfig(regularize_percentile=90.0, max_volume_fac=20.0), alloc_fac_nodes=1.0, nsamp=1024, mass_centered=False))

Main nested config for k nearest neighbour search

Parameters:
  • alloc_fac_ilist (float)

  • tree (TreeConfig)

jztree.data

This module contains dataclasses that define interfaces between different parts of the code. Particularly relevant are the particle data classes, like Pos, which are needed in multi-GPU setups to keep track of particle counts.

class jztree.data.Pos(*, pos, num=None, num_total=None)

A dataclass holding positions.

num and num_total are only required for multi-GPU setups, where the number of filled entries may be less than the array size and may depend on run-time data.

Parameters:
  • pos (object) – Position array of shape (size, dim)

  • num (int | object | None) – Filled count (<=size) on local device, needed for multi-GPU, may be dynamic

  • num_total (int | None) – Total count on all devices, static
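As a rough illustration of the padded-buffer idea, the sketch below uses a hypothetical `DemoPos` class with plain lists instead of device arrays. Only the field names follow the Pos interface; the helper `valid_positions` is invented for this example.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DemoPos:  # hypothetical stand-in, not jztree.data.Pos
    pos: List[List[float]]           # shape (size, dim), may contain padding
    num: Optional[int] = None        # filled count on this device (<= size)
    num_total: Optional[int] = None  # total count over all devices

def valid_positions(p: DemoPos) -> List[List[float]]:
    """Return only the filled entries; without num, everything is valid."""
    n = len(p.pos) if p.num is None else p.num
    return p.pos[:n]

nan = float("nan")
# Two "devices", each with an allocation of size 3 but different fill counts.
dev0 = DemoPos(pos=[[0.1, 0.2], [0.3, 0.4], [nan, nan]], num=2, num_total=3)
dev1 = DemoPos(pos=[[0.5, 0.6], [nan, nan], [nan, nan]], num=1, num_total=3)
```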

class jztree.data.PosMass(*, pos, mass, num=None, num_total=None)

Dataclass holding positions and masses, see Pos for details

Parameters:
  • pos (object)

  • mass (object)

  • num (object | None)

  • num_total (int | None)

class jztree.data.ParticleData(*, pos, mass=None, vel=None, id=None, num=None, num_total=None)

Dataclass holding positions and optional data, see Pos for details

Parameters:
  • pos (object)

  • mass (object | None)

  • vel (object | None)

  • id (object | None)

  • num (object | None)

  • num_total (int | None)

class jztree.data.PackedArray(data, ispl, fill_values=None, levels_filled=<factory>)

A dataclass that stacks several dynamic arrays into a single buffer

Parameters:
  • data (object)

  • ispl (object)

  • fill_values (object | None)

  • levels_filled (object)
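The stacking idea can be sketched in a few lines of plain Python. `DemoPacked` is a hypothetical, simplified analogue: it keeps a fixed-size `data` buffer plus split points `ispl`, and `get(level)` slices one of the stacked arrays back out. The real class may store and access its data differently.

```python
class DemoPacked:  # hypothetical, simplified analogue of PackedArray
    def __init__(self, arrays, size, fill_value=0):
        self.ispl = [0]  # split points: array k lives in data[ispl[k]:ispl[k+1]]
        flat = []
        for a in arrays:
            flat.extend(a)
            self.ispl.append(len(flat))
        if len(flat) > size:
            raise ValueError("allocation too small")
        # Pad up to the fixed allocation size known ahead of time.
        self.data = flat + [fill_value] * (size - len(flat))

    def get(self, level):
        return self.data[self.ispl[level]:self.ispl[level + 1]]

packed = DemoPacked([[1, 2, 3], [4, 5], [6]], size=8)
```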

class jztree.data.LevelInfo(dim, dtype)

Data class holding info on minimum and maximum Morton level of a tree

Parameters:
  • dim (int)

  • dtype (dtype)

class jztree.data.TreeHierarchy(size_leaves, ispl_n2n, ispl_n2l, ispl_l2p_per_type, lvl, geom_cent, mass=None, mass_cent=None)

Dataclass holding the tree-plane hierarchy

Most properties are stored in PackedArray classes that stack them contiguously in a single buffer for all tree-planes

Parameters:
  • size_leaves (int) – Size of the leaf-node level allocation

  • ispl_n2n (PackedArray) – The node-to-node splitting points. Node i on plane p includes all plane p-1 nodes in the range ispl_n2n.get(p)[i] … ispl_n2n.get(p)[i+1]

  • ispl_n2l (PackedArray) – Splitting points that go directly from nodes to leaves.

  • ispl_l2p_per_type (List[object]) – Leaf-to-particle splitting points per particle type.

  • lvl (PackedArray) – Morton-level of nodes. Can be used to define extent of nodes.

  • geom_cent (PackedArray) – Geometric centers

  • mass (PackedArray | None) – Mass of nodes (only available for mass_centered tree)

  • mass_cent (PackedArray | None) – Mass centers of nodes (only available for mass_centered tree)
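To make the split-point convention concrete, here is a small sketch with plain lists. The variable names are hypothetical and the numbers are invented; it only shows how ranges of children and particle counts follow from consecutive split points.

```python
# Leaf-to-particle splits: 4 leaves covering 10 particles.
ispl_l2p = [0, 3, 5, 9, 10]
# Node-to-node splits for plane 1: 2 nodes, each covering 2 leaves.
ispl_n2n_1 = [0, 2, 4]

def children(ispl, i):
    """Indices of the lower-plane entries contained in entry i."""
    return range(ispl[i], ispl[i + 1])

def npart_plane1(i):
    """Particle count of plane-1 node i, obtained by chaining both splits."""
    first_leaf, end_leaf = ispl_n2n_1[i], ispl_n2n_1[i + 1]
    return ispl_l2p[end_leaf] - ispl_l2p[first_leaf]
```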

splits_leaf_to_part(ptype=0, size=None)

Convenience method to get the leaf-to-particle splits for a specific type

Parameters:
  • ptype (int)

  • size (int | None)

Return type:

object

npart(level, ptype=0, size=None)

Number of particles in a node

Parameters:
  • level (int)

  • ptype (int)

Return type:

object

center()

Returns mass_cent or geom_cent, depending on how tree was constructed

Return type:

PackedArray

num_planes()

Number of planes in the tree

Return type:

int

num(level)

Number of nodes in a given plane

Return type:

int

size()

The recommended allocation size at the leaf level

Return type:

int

info()

Info about the minimum and maximum Morton levels of the tree

Return type:

LevelInfo

class jztree.data.InteractionList(ispl, isrc, rad2=None, ids=None, dev_spl=None)

A dataclass that holds interaction information for dual tree-walks

The interaction list is defined so that a receiving node i needs to interact with all source nodes in the range isrc[ispl[i]:ispl[i+1]]

Parameters:
  • ispl (object) – splitting points of interaction list segments

  • isrc (object) – interaction source indices

  • rad2 (object | None) – interaction radii squared. Optional – so far only used in kNN algorithm

  • ids (object | None) – used in multi-GPU scenarios to define for each source index the origin index

  • dev_spl (object | None) – used in multi-GPU scenarios to define for each unique source index the origin rank. (Sources in range dev_spl[r]:dev_spl[r+1] belong to rank r.)
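The segment layout isrc[ispl[i]:ispl[i+1]] can be walked as in the sketch below. The arrays and the per-node masses are invented example data, and the "interaction" is just a sum over source masses.

```python
ispl = [0, 2, 3, 6]          # three receiving nodes
isrc = [0, 1, 1, 0, 1, 2]    # concatenated source indices
node_mass = [1.0, 2.0, 4.0]  # invented per-node quantity

def sources_of(i):
    """Source nodes that receiving node i interacts with."""
    return isrc[ispl[i]:ispl[i + 1]]

# A toy "interaction": sum the mass of all sources of each receiver.
received = [sum(node_mass[j] for j in sources_of(i))
            for i in range(len(ispl) - 1)]
```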

without_remote_query_points(rank)

By default the interaction list carries remote and local points both on ispl so that query and source indices are consistent. However, this function redefines the interaction list so that ispl is only defined for local query points, but interaction indices may still point to remote points

Parameters:
  • rank (int)

Return type:

InteractionList

nfilled()

Number of filled entries

size()

Size of the interaction indices array

dtype()

Index datatype (so far always int32)

class jztree.data.Label(irank, igroup, ilocal_segment=None)

A FoF-group label for multi-GPU, pointing to a root particle’s rank and index

Parameters:
  • irank (object)

  • igroup (object)

  • ilocal_segment (object | None)

class jztree.data.FofCatalogue(ngroups, mass=None, count=None, offset=None, com_pos=None, com_vel=None, com_inertia_radius=None, scale_factor=None, v_rad=None, offset_rank=None)

Friends-of-friends catalogue data

Many properties are optional and will be set to None if calculation was lacking data. To allow jax.jit compatibility the arrays are generally allocated larger than what is actually needed, with ngroups indicating the filled count. You may use squeeze_catalogue() (outside of jax.jit) to obtain a squeezed version.

Parameters:
  • ngroups (object) – Actual number of groups. (On multi-GPU the local count.)

  • mass (object | None) – Group masses

  • count (object | None) – Number of particles

  • offset (object | None) – Starting point of each group in the particle array

  • com_pos (object | None) – Center of mass position

  • com_vel (object | None) – Center of mass velocity

  • com_inertia_radius (object | None) – Inertia radius \(\sqrt{\langle (\mathbf{x} - \mathbf{x}_0)^2 \rangle}\).

  • scale_factor (object | None) – For light-cone data: scale factor of light-cone crossing.

  • v_rad (object | None) – For light-cone data: radial (line of sight) velocity

  • offset_rank (object | None) – Provided for squeezed catalogues to indicate origin rank.
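How offset and count address a group, and how the centre of mass and inertia radius follow from them, can be sketched as below. This assumes equal particle masses and takes the reference point x_0 to be the centre of mass; both are assumptions of the sketch, and the data is invented.

```python
import math

# Particles already in group order; two groups of 2 and 3 particles.
pos = [[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0], [10.0, 14.0]]
offset = [0, 2]
count = [2, 3]

def group_stats(g):
    """Centre of mass and inertia radius of group g (equal masses assumed)."""
    seg = pos[offset[g]:offset[g] + count[g]]
    dim = len(seg[0])
    com = [sum(p[d] for p in seg) / len(seg) for d in range(dim)]
    r2 = sum(sum((p[d] - com[d]) ** 2 for d in range(dim)) for p in seg) / len(seg)
    return com, math.sqrt(r2)
```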

class jztree.data.RankIdx(rank, idx)

Holds a rank and an index – used to point to a particle in multi-GPU setups.

Parameters:
  • rank (object)

  • idx (object)

jztree.data.expand_particles(part, ndev)

Expands particles from shape (Ndev*N) -> (Ndev,N)

Useful to interface with shard maps, since (Ndev,N) shape is assumed in general

Parameters:
  • part (Pos)

  • ndev (int)
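The shape change itself is an ordinary reshape of the leading axis. The sketch below shows it on plain lists with hypothetical helper names; the library functions apply the same idea to every array in a particle dataclass.

```python
def expand(flat, ndev):
    """(Ndev*N) -> (Ndev, N): cut the leading axis into ndev equal chunks."""
    n, rem = divmod(len(flat), ndev)
    if rem:
        raise ValueError("length must be divisible by ndev")
    return [flat[d * n:(d + 1) * n] for d in range(ndev)]

def flatten(nested):
    """(Ndev, N) -> (Ndev*N): the inverse operation."""
    return [x for row in nested for x in row]

particles = list(range(6))
per_device = expand(particles, ndev=2)
```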

jztree.data.flatten_particles(part)

Flattens particles from shape (Ndev,N) -> (Ndev*N).

Parameters:
  • part (Pos)

jztree.data.pad_particles(part, num, float_val=jnp.nan, int_val=0)

Pads particle data (to leave space for communication).

Parameters:
  • part (Pos)

  • num (int)

  • float_val (float)

  • int_val (int)
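A minimal sketch of padding with type-dependent fill values follows. The helper `pad` and its `n_extra` argument are hypothetical; how `n_extra` relates to the num argument of pad_particles is not stated in this reference, so no mapping is implied.

```python
import math

def pad(values, n_extra, fill):
    """Append n_extra fill entries (n_extra is a choice of this sketch)."""
    return list(values) + [fill] * n_extra

pos_x = pad([0.1, 0.7], n_extra=2, fill=float("nan"))  # floats padded with NaN
ids = pad([11, 12], n_extra=2, fill=0)                 # integers padded with 0
```

Padding floats with NaN makes accidental use of unfilled slots easy to spot, since any arithmetic on them stays NaN.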

jztree.data.squeeze_catalogue(cata, size_out=None, offset_mode='rank', nparts=None)

Squeezes a multi-GPU FoF catalogue that was returned from a shard_map (Ndev,size_group) into a dense form (Ntot) and replicates it on every device

Parameters:
  • cata (FofCatalogue) – FoF catalogue

  • size_out (int | None) – Can be provided to allow jit compatibility

  • offset_mode (str) – Can be “rank” or “flat”. Before squeezing, offsets indicate locations in the particle array of the same device. Since squeezing loses the device info, we need to either indicate the rank or convert the offsets to global offsets that index a squeezed particle array (converts to int64). “global” needs “nparts” as an input, i.e. how many particles were on each device

  • nparts (object | None) – particles that are on each device, needed for “global” offset_mode

Return type:

FofCatalogue
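One plausible reading of the conversion from device-local to global offsets is sketched below. It assumes the squeezed particle array is the concatenation of the filled particles of each device in rank order, which this reference does not spell out; all numbers are invented.

```python
from itertools import accumulate

nparts = [4, 3, 5]  # filled particles per device (invented)
# First global index of each device: exclusive prefix sum of nparts.
start = [0] + list(accumulate(nparts))[:-1]

# (rank, local offset) of each group before squeezing.
local = [(0, 0), (0, 2), (1, 1), (2, 0), (2, 3)]
global_offset = [start[rank] + off for rank, off in local]
```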

jztree.tree

This module contains functions for sorting particles into z-order, for building a plane-based tree hierarchy and for defining interaction lists.

jztree.tree.zsort(pos)

Brings position vectors into z-order

If pos is a pytree, it needs to have a “pos” attribute which will be used as the sorting key. All remaining leaves of the pytree with the same length will be sorted accordingly along the leading axis

Parameters:
  • pos (object | Pos) – Can be a jax.Array or a more complicated particle data structure following the jztree.data.Pos interface. All pytree-leaves with the same leading shape as the .pos attribute will be sorted.

Returns:

(posz, idz) – Sorted positions and ids so that posz = pos[idz]

Return type:

Tuple[object | Pos, object]
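For readers unfamiliar with z-order, the sketch below sorts points by a Morton key built from interleaved grid-cell bits. The integer grid, `morton_key` and `zsort_demo` are assumptions of this example; the library's own ordering may be computed differently (for instance directly on floating-point values).

```python
def morton_key(cell, bits=8):
    """Interleave the bits of the integer cell coordinates."""
    key, dim = 0, len(cell)
    for b in range(bits):
        for d, c in enumerate(cell):
            key |= ((c >> b) & 1) << (b * dim + d)
    return key

def zsort_demo(pos, boxsize=1.0, bits=8):
    ncell = 1 << bits
    cells = [[min(int(x / boxsize * ncell), ncell - 1) for x in p] for p in pos]
    idz = sorted(range(len(pos)), key=lambda i: morton_key(cells[i], bits))
    return [pos[i] for i in idz], idz  # posz = pos[idz]

pos = [[0.9, 0.9], [0.1, 0.1], [0.9, 0.1], [0.1, 0.9]]
posz, idz = zsort_demo(pos)
```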

jztree.tree.search_sorted_z(posz, posz_query, block_size=64, leaf_search=False)

Finds the indices in posz where elements of posz_query would be inserted to keep order.

This is similar to np.searchsorted, but works for points sorted in Z-order. On equality it maintains the rule posz[idx] < v <= posz[idx+1]. If leaf_search is True, it is assumed that posz contains one point per leaf, and the index of the leaf that the query point belongs to is returned.

Parameters:
  • posz (object)

  • posz_query (object)

  • block_size (int)

  • leaf_search (bool)
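The index rule can be illustrated on scalar keys that stand in for z-ordered points. Both helpers are hypothetical, and the leaf variant additionally assumes that each leaf is represented by its smallest key, which this reference does not state.

```python
from bisect import bisect_left, bisect_right

keys = [3, 8, 8, 20, 41]  # sorted keys standing in for z-ordered points

def search_sorted_keys(keys, v):
    """Index idx such that keys[idx] < v <= keys[idx + 1]."""
    return bisect_left(keys, v) - 1

def leaf_of(leaf_start_keys, v):
    """Leaf containing v, assuming each leaf is given by its smallest key."""
    return bisect_right(leaf_start_keys, v) - 1
```

A result of -1 from `search_sorted_keys` means the query lies at or before the first key.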

jztree.tree.zsort_and_tree(part, cfg_tree=TreeConfig(), data=None, ptype=None, num_types=None, shrink=True)

Brings particles into z-order and creates a tree hierarchy.

In the multi-GPU scenario, particles may be communicated and their balance adjusted to ensure that top-level nodes don’t cross domain boundaries.

For multi-type trees, please use zsort_and_tree_multi_type()

Parameters:
  • part (Pos) – position array or dataclass

  • cfg_tree (TreeConfig) – low-level configuration options

  • data (Any | None) – optional: additional particle data that will be carried along through sorting and communication. Can be a pytree and will be provided as an additional output

  • ptype (object | None) – optional, a type-index per particle to distinguish particle types in tree

  • num_types (int | None) – number of particle types

  • shrink (bool) – if False, always return three outputs

Return type:

Tuple[Pos, TreeHierarchy]

jztree.tree.zsort_and_tree_multi_type(part, cfg_tree=TreeConfig(), data=None)

Builds a tree with multiple particle types

The tree’s leaf-to-particle splits jztree.data.TreeHierarchy.ispl_l2p_per_type will be defined separately for each type.

Parameters:
  • part (Tuple[Pos, ...]) – tuple of particle arrays of different types

  • cfg_tree (TreeConfig) – config object controlling tree structure and allocation

  • data (Any | None) – additional particle data to be carried along

Returns:

(partz, th) – sorted particles and tree hierarchy. If data is provided additionally returns the sorted data.

Return type:

Tuple[Tuple[Pos, …], TreeHierarchy]

jztree.tree.build_tree_hierarchy(partz, cfg_tree, lvl_bound=None, ptype=None, num_types=None)

Builds a tree hierarchy from z-order positions

The zeroth level of the tree corresponds to leaves, which contain multiple particles. Nodes (and leaves) are selected so that they are as big as possible while not containing more than a maximum number of particles that starts at cfg_tree.max_leaf_size and increases per level by a factor cfg_tree.coarse_fac.

Nodes are parameterized through a set of splits. For example, the particles that lie in the leaf with index i are given by part[ispl_n2n.get(level=0)[i]: ispl_n2n.get(level=0)[i+1]]. The ith node of level n contains all level n-1 nodes in the range ispl_n2n.get(n)[i]: ispl_n2n.get(n)[i+1].

In jax, memory sizes need to be known at compile time, but the required number of nodes on each level is data dependent. To limit the number of allocations that we need to predict, we use the PackedArray class, which stacks multiple levels into a single contiguous array while allowing access “almost” as if they were separate arrays.

It is recommended to use the zsort_and_tree() interface, especially for multi-GPU, since the particle distribution needs to be adjusted to accommodate nodes.
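The per-level particle limit described above grows geometrically. The sketch below only evaluates that rule, max_leaf_size times coarse_fac to the power of the level; the function name and the number of levels shown are choices of this example.

```python
def max_particles_per_level(max_leaf_size=32, coarse_fac=6.0, nlevels=4):
    """Upper bound on particles per node for levels 0 .. nlevels-1."""
    return [max_leaf_size * coarse_fac ** lvl for lvl in range(nlevels)]

default_limits = max_particles_per_level()             # TreeConfig defaults
fof_limits = max_particles_per_level(48, 8.0, nlevels=3)  # values from FofConfig
```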

Parameters:
  • partz (PosMass | object)

  • cfg_tree (TreeConfig)

  • lvl_bound (object | None)

  • ptype (object | None)

  • num_types (object | None)

Return type:

TreeHierarchy

jztree.tree.grouped_dense_interaction_list(nnodes, size_ilist, ngroup=32, size_super=None, node_range=None, dtype=jnp.int32)

Defines an all-to-all interaction list over super-nodes and a super-node to node relation

This is useful for evaluating all-to-all interactions in a grouped manner on GPU

Parameters:
  • nnodes (object | int) – number of top-level nodes

  • size_ilist (int) – allocation size of the interaction list

  • ngroup (int) – how many top-level nodes should be summarized into one super node

  • size_super (int | None) – size of the super-node allocation

  • node_range (object | None) – can be provided to only include receiving nodes in this range, (but source nodes on the full domain. Useful for multi-GPU scenarios)

  • dtype – data-dtype of the interaction list. So far only int32 supported.

Returns:

A tuple (spl_super, ilist, nsuper_nodes) containing the super node splitting points, the interaction list and the number of super nodes.

Return type:

Tuple[object, InteractionList, object]
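A plain-Python reading of the grouping idea is sketched below: consecutive nodes are bundled into super nodes of ngroup nodes each, and every super node interacts with every other. The helper and its return layout are hypothetical and ignore the fixed-size allocations (size_ilist, size_super) and node_range of the real function.

```python
def grouped_dense(nnodes, ngroup=32):
    nsuper = -(-nnodes // ngroup)  # ceiling division
    # Super node s contains nodes spl_super[s] .. spl_super[s+1]-1.
    spl_super = [min(s * ngroup, nnodes) for s in range(nsuper + 1)]
    # All-to-all list: each receiving super node sees all nsuper sources.
    ispl = [i * nsuper for i in range(nsuper + 1)]
    isrc = [j for _ in range(nsuper) for j in range(nsuper)]
    return spl_super, ispl, isrc, nsuper

spl_super, ispl, isrc, nsuper = grouped_dense(70, ngroup=32)
```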

jztree.tree.simplify_interaction_list(ilist, always_keep=None)

Reduces an interaction list to skip nodes that don’t appear as receiving or source indices

Useful in multi-GPU scenarios where many non-local nodes will not have any local interactions

Parameters:
  • ilist

  • always_keep

Return type:

InteractionList

jztree.knn

This module contains functions for doing k-nearest neighbour search.

jztree.knn.knn(part, k, boxsize=None, th=None, part_query=None, result='rad_globalidx', reduce_func=None, output_order='input', cfg=KNNConfig())

The main function executing the nearest neighbour search.

By default this returns (rnn, inn) – the radii and indices of the k nearest neighbours. Arguments may modify returned results and their ordering.

Parameters:
  • part (object | Pos) – Source positions, may be a jax.Array or a more complex dataclass following the jztree.data.Pos interface.

  • k (int) – Number of neighbours to calculate per particle

  • boxsize (float | None) – If provided, distance calculations are wrapped periodically.

  • th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.

  • part_query (object | Pos | None) – Query positions, defaults to part.

  • result (str | object | Any) – String indicating the desired return values. May have any of the following separated by underscores: “rad” (radii), “drad” (differentiable radii), “rankidx” (origin rank and index jztree.data.RankIdx), “globalidx” (linear global index), “part” (mapped source particles), “reduce” (see reduce_func) or the name of any attribute on the particle data structure.

  • reduce_func (Callable | None) – Will be called with the neighbour list in z-order to get a summary statistic per query particle. Useful to avoid communicating the neighbour list to origin tasks in distributed setups. Requires “reduce” inside of result.

  • output_order (str) – May be “input” or “z”.

  • cfg (KNNConfig) – Config object that controls lower-level details of the algorithm

Returns:

tuple (rnn, inn) with neighbour radii and indices

Return type:

Tuple[object, object]
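To clarify what the default outputs (rnn, inn) mean, here is a brute-force O(N²) reference in plain Python, including periodic wrapping. It is not the tree algorithm used by the library, and it counts each query point as its own nearest neighbour, which may or may not match the library's convention.

```python
import math

def knn_brute(pos, k, boxsize=None):
    """Radii and indices of the k nearest neighbours of every point."""
    rnn, inn = [], []
    for q in pos:
        dists = []
        for j, p in enumerate(pos):
            d2 = 0.0
            for a, b in zip(q, p):
                dx = a - b
                if boxsize is not None:
                    dx -= boxsize * round(dx / boxsize)  # minimum image
                d2 += dx * dx
            dists.append((math.sqrt(d2), j))
        dists.sort()
        rnn.append([d for d, _ in dists[:k]])
        inn.append([j for _, j in dists[:k]])
    return rnn, inn

pos = [[0.05], [0.95], [0.5]]
rnn, inn = knn_brute(pos, k=2, boxsize=1.0)
rnn_open, inn_open = knn_brute(pos, k=2)
```

With the periodic box, the points at 0.05 and 0.95 are neighbours across the boundary; without it, the point at 0.5 is closer to 0.05.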

jztree.fof

This module contains functions to calculate friends-of-friends (FoF) group labels, for bringing particles into group order and for calculating a group catalogue.

jztree.fof.fof_is_superset(igroup_sup, igroup, mask=None)

Checks whether every FoF group in igroup_sup is a superset of sets in igroup

Useful for comparing different FoF implementations where labels may differ even when they represent the same groups.
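The superset relation can be checked without requiring identical label values, as in the sketch below. The function is a hypothetical re-implementation of the idea on plain lists; it assumes mask selects which particles take part in the comparison.

```python
def is_superset(igroup_sup, igroup, mask=None):
    """True if every group of igroup lies inside a single group of igroup_sup."""
    seen = {}  # fine label -> coarse label it was first seen with
    for i, (gs, g) in enumerate(zip(igroup_sup, igroup)):
        if mask is not None and not mask[i]:
            continue
        if seen.setdefault(g, gs) != gs:
            return False
    return True

coarse = [0, 0, 0, 3, 3]
fine = [0, 0, 2, 3, 3]
```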

jztree.fof.fof_labels(part, rlink, boxsize=0., cfg=FofConfig(), th=None, output_order='z')

Calculates the friends-of-friends group relationship for single-GPU cases

For multi-GPU, please use distr_fof_labels()

Parameters:
  • part (object | Pos) – particles, can be a jax.Array or any particle data structure with a .pos attribute

  • rlink (float) – (absolute) linking length

  • boxsize (float) – if provided, distances use periodic wrapping

  • cfg (FofConfig) – controls low-level details

  • th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.

  • output_order (str) – may be “z” or “input”

Returns:

(part, igroup) – (possibly) reordered particles and group labels – igroup points to the first particle in z-order that belongs to the same group.

Return type:

Tuple[object | Pos, object]
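The label convention (each particle points to the first particle of its group) can be reproduced with a brute-force union-find reference. This sketch works in input order rather than z-order, treats boxsize 0 as "no periodic wrapping", and links pairs with distance <= rlink; all three are assumptions of the example.

```python
def fof_brute(pos, rlink, boxsize=0.0):
    n = len(pos)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            d2 = 0.0
            for a, b in zip(pos[i], pos[j]):
                dx = a - b
                if boxsize > 0:
                    dx -= boxsize * round(dx / boxsize)
                d2 += dx * dx
            if d2 <= rlink ** 2:
                ri, rj = find(i), find(j)
                parent[max(ri, rj)] = min(ri, rj)  # smallest index stays root
    return [find(i) for i in range(n)]

pos = [[0.0], [0.1], [0.5], [0.58], [0.95]]
labels_open = fof_brute(pos, rlink=0.12)
labels_periodic = fof_brute(pos, rlink=0.12, boxsize=1.0)
```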

jztree.fof.distr_fof_labels(part, rlink, boxsize=0., cfg=FofConfig(), th=None, linearize_labels=False)

Calculates the friends-of-friends group relationship for multi-GPU cases.

For single-GPU, please use fof_labels()

Parameters:
  • part (Pos) – particles, should have some padding and follow the jztree.data.Pos interface

  • rlink (float) – (absolute) linking length

  • boxsize (float) – if provided, distances use periodic wrapping

  • cfg (FofConfig) – controls low-level details

  • th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.

  • linearize_labels (bool) – if true, group labels will be output as a scalar global index. Only use this for comparing small setups against single-GPU.

Returns:

(partz, label) – z-ordered particles and group labels – label points to the rank and index of the first particle in z-order that belongs to the same group.

Return type:

Tuple[Pos, Label]

jztree.fof.fof_and_catalogue(part, rlink, boxsize=0., cfg=FofConfig(), th=None)

Puts particles in FoF-order and calculates the FoF-catalogue

Supports single- and multi-GPU.

Parameters:
  • part (ParticleData) – particles, should follow at least the interface of jztree.data.Pos and may optionally contain additional elements as in jztree.data.ParticleData for calculating additional data in the catalogue. For multi-GPU, particles need to be padded to allow space for communication.

  • rlink (float) – (absolute) linking length

  • boxsize (float) – if provided, distances use periodic wrapping

  • cfg (FofConfig) – controls low-level details

  • th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.

Returns:

(partf, catalogue) – particles in FoF-order and the group catalogue. Each group forms a continuous segment in the particle array, but the last group on each rank may continue on the next rank.

Return type:

Tuple[ParticleData, FofCatalogue]
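The relation between group labels, FoF-order and the catalogue's offset and count can be sketched on a single device as follows. The ordering of groups by label and the interpretation of npart_min as a minimum group size are assumptions of this example, not statements about the library.

```python
def group_order(labels, npart_min=1):
    """Permutation into group order plus (offset, count) per kept group."""
    order = sorted(range(len(labels)), key=lambda i: (labels[i], i))
    offset, count = [], []
    for k, i in enumerate(order):
        if k == 0 or labels[i] != labels[order[k - 1]]:
            offset.append(k)  # a new contiguous segment starts here
            count.append(0)
        count[-1] += 1
    keep = [g for g in range(len(count)) if count[g] >= npart_min]
    return order, [offset[g] for g in keep], [count[g] for g in keep]

labels = [0, 0, 2, 2, 0]
order, offset, count = group_order(labels)
_, offset_big, count_big = group_order(labels, npart_min=3)
```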