API Reference
The API of jz-tree. Only functions that may be relevant for interfacing with and calling the code are listed. Developers should consult the source code directly.
jztree.config
This module contains config dataclasses that bundle arguments controlling low-level details of the code. Relevant configs are always passed as function arguments and trigger jit recompilation if modified. The main reasons to pass non-default configurations are that some allocation turned out too small, or to optimize memory usage.
- class jztree.config.RegularizationConfig(regularize_percentile=90.0, max_volume_fac=20.0)
Config controlling regularization.
- Parameters:
regularize_percentile (float)
max_volume_fac (float)
- class jztree.config.TreeConfig(max_leaf_size=32, coarse_fac=6.0, stop_coarsen=1024, regularization=None, alloc_fac_nodes=1.0, nsamp=1024, mass_centered=False)
Config dataclass controlling tree structure and allocation.
- Parameters:
max_leaf_size (int)
coarse_fac (float)
stop_coarsen (int)
regularization (RegularizationConfig | None)
alloc_fac_nodes (float)
nsamp (int)
mass_centered (bool)
- class jztree.config.FofCatalogueConfig(npart_min=20)
Config dataclass controlling aspects of friends-of-friends catalogues
- Parameters:
npart_min (int)
- class jztree.config.FofConfig(alloc_fac_ilist=32.0, alloc_fac_distr_links=0.01, tree=TreeConfig(max_leaf_size=48, coarse_fac=8.0, stop_coarsen=2048, regularization=None, alloc_fac_nodes=1.1, nsamp=1024, mass_centered=False), catalogue=FofCatalogueConfig(npart_min=20))
Main nested config for friends-of-friends
- Parameters:
alloc_fac_ilist (float)
alloc_fac_distr_links (float)
tree (TreeConfig)
catalogue (FofCatalogueConfig)
- class jztree.config.KNNConfig(alloc_fac_ilist=256.0, tree=TreeConfig(max_leaf_size=48, coarse_fac=8.0, stop_coarsen=2048, regularization=RegularizationConfig(regularize_percentile=90.0, max_volume_fac=20.0), alloc_fac_nodes=1.0, nsamp=1024, mass_centered=False))
Main nested config for k nearest neighbour search
- Parameters:
alloc_fac_ilist (float)
tree (TreeConfig)
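As a rough illustration of how a nested config might be adjusted when an allocation turns out too small, the sketch below uses stand-in dataclasses that copy a few of the documented fields and defaults. They are not the jztree classes, and whether the real configs are frozen dataclasses that work with `dataclasses.replace` is an assumption.

```python
from dataclasses import dataclass, field, replace

# Stand-in classes for illustration only: they copy a few documented fields
# and defaults, but they are NOT imported from jztree.
@dataclass(frozen=True)
class TreeConfig:
    max_leaf_size: int = 48
    coarse_fac: float = 8.0
    alloc_fac_nodes: float = 1.0

@dataclass(frozen=True)
class KNNConfig:
    alloc_fac_ilist: float = 256.0
    tree: TreeConfig = field(default_factory=TreeConfig)

default_cfg = KNNConfig()

# Double the interaction-list allocation and give the node buffers 20% headroom.
bigger_cfg = replace(
    default_cfg,
    alloc_fac_ilist=2 * default_cfg.alloc_fac_ilist,
    tree=replace(default_cfg.tree, alloc_fac_nodes=1.2),
)

# The original config is untouched; the modified one compares as different.
print(default_cfg)
print(bigger_cfg)
print(default_cfg == bigger_cfg)
```

Building a new config object instead of mutating the old one keeps the defaults intact and makes the change visible to anything that compares or hashes the config.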
jztree.data
This module contains some dataclasses that define interfaces between different parts of the code.
Particularly relevant are the particle data classes, like Pos, which are needed in
multi-GPU setups to keep track of particle counts.
- class jztree.data.Pos(*, pos, num=None, num_total=None)
A dataclass holding positions.
num and num_total are only required for multi-GPU setups, where the number of filled entries may be less than the size and is run-time dependent.
- Parameters:
pos (object) – Position array of shape (size, dim)
num (int | object | None) – Filled count (<=size) on local device, needed for multi-GPU, may be dynamic
num_total (int | None) – Total count on all devices, static
- class jztree.data.PosMass(*, pos, mass, num=None, num_total=None)
Dataclass holding positions and masses, see Pos for details.
- Parameters:
pos (object)
mass (object)
num (object | None)
num_total (int | None)
- class jztree.data.ParticleData(*, pos, mass=None, vel=None, id=None, num=None, num_total=None)
Dataclass holding positions and optional data, see Pos for details.
- Parameters:
pos (object)
mass (object | None)
vel (object | None)
id (object | None)
num (object | None)
num_total (int | None)
- class jztree.data.PackedArray(data, ispl, fill_values=None, levels_filled=<factory>)
A dataclass that stacks several dynamic arrays into a single buffer
- Parameters:
data (object)
ispl (object)
fill_values (object | None)
levels_filled (object)
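The idea of stacking several variable-length arrays into one fixed-size buffer can be sketched in plain NumPy. This is only an illustration of the concept: the helper `get_level`, the buffer size and the fill value are hypothetical, and the actual PackedArray implementation may differ (for example in how it keeps shapes static under jit).

```python
import numpy as np

# Three per-level arrays whose lengths are data dependent.
levels = [np.array([3, 1, 4, 1, 5]), np.array([9, 2]), np.array([6])]

size = 12        # fixed allocation, chosen larger than the total filled length
fill_value = -1  # marks unused entries at the end of the buffer

# ispl[l] .. ispl[l+1] delimits level l inside the shared buffer.
ispl = np.concatenate([[0], np.cumsum([len(a) for a in levels])])

data = np.full(size, fill_value)
data[: ispl[-1]] = np.concatenate(levels)

def get_level(data, ispl, level):
    """Hypothetical helper: return the part of the buffer that holds one level."""
    return data[ispl[level] : ispl[level + 1]]

print(ispl)
print(get_level(data, ispl, 1))
```

Only the single buffer size has to be predicted in advance; the split points carry the information about where each level starts and ends.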
- class jztree.data.LevelInfo(dim, dtype)
Data class holding info on minimum and maximum Morton level of a tree
- Parameters:
dim (int)
dtype (dtype)
- class jztree.data.TreeHierarchy(size_leaves, ispl_n2n, ispl_n2l, ispl_l2p_per_type, lvl, geom_cent, mass=None, mass_cent=None)
Dataclass holding the tree-plane hierarchy
Most properties are stored in PackedArray classes that stack them contiguously in a single buffer for all tree-planes.
- Parameters:
size_leaves (int) – Size of the leaf-node level allocation
ispl_n2n (PackedArray) – The node-to-node splitting points. Node i on plane p includes all plane p-1 nodes in the range ispl_n2n.get(p)[i] … ispl_n2n.get(p)[i+1]
ispl_n2l (PackedArray) – Splitting points that go directly from nodes to leaves.
ispl_l2p_per_type (List[object]) – Leaf-to-particle splitting points per particle type.
lvl (PackedArray) – Morton-level of nodes. Can be used to define extent of nodes.
geom_cent (PackedArray) – Geometric centers
mass (PackedArray | None) – Mass of nodes (only available for mass_centered tree)
mass_cent (PackedArray | None) – Mass centers of nodes (only available for mass_centered tree)
- splits_leaf_to_part(ptype=0, size=None)
Convenience method to get the leaf-to-particle splits for a specific type
- Parameters:
ptype (int)
size (int | None)
- Return type:
object
- npart(level, ptype=0, size=None)
Number of particles in a node
- Parameters:
level (int)
ptype (int)
- Return type:
object
- center()
Returns mass_cent or geom_cent, depending on how tree was constructed
- Return type:
- num_planes()
Number of planes in the tree
- Return type:
int
- num(level)
Number of nodes in a given plane
- Return type:
int
- size()
The recommended allocation size at the leaf level
- Return type:
int
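The splitting-point convention used throughout the hierarchy can be illustrated with small NumPy arrays. The numbers below are made up, and plain arrays stand in for the PackedArray accessors; the sketch only shows how ranges and particle counts follow from consecutive splits.

```python
import numpy as np

# Hypothetical splitting points for a tiny tree.
# Leaf i owns the particles ispl_l2p[i] : ispl_l2p[i+1].
ispl_l2p = np.array([0, 3, 5, 9, 10])   # 4 leaves, 10 particles
# Node i owns the leaves ispl_n2l[i] : ispl_n2l[i+1].
ispl_n2l = np.array([0, 2, 4])          # 2 nodes, 4 leaves

# Particles per leaf are the differences of consecutive splits.
npart_leaf = np.diff(ispl_l2p)

# Particles per node: evaluate the particle splits at the node's leaf boundaries.
npart_node = np.diff(ispl_l2p[ispl_n2l])

print(npart_leaf)
print(npart_node)
```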
- class jztree.data.InteractionList(ispl, isrc, rad2=None, ids=None, dev_spl=None)
A dataclass that holds interaction information for dual tree-walks
The interaction list is defined so that a receiving node i needs to interact with all source nodes in the range isrc[ispl[i]:ispl[i+1]]
- Parameters:
ispl (object) – splitting points of interaction list segments
isrc (object) – interaction source indices
rad2 (object | None) – interaction radii squared. Optional – so far only used in kNN algorithm
ids (object | None) – used in multi-GPU scenarios to define for each source index the origin index
dev_spl (object | None) – used in multi-GPU scenarios to define for each unique source index the origin rank. (Sources in range dev_spl[r]:dev_spl[r+1] belong to rank r.)
- without_remote_query_points(rank)
By default, the interaction list carries both remote and local points on ispl so that query and source indices are consistent. This function redefines the interaction list so that ispl is only defined for local query points, while interaction indices may still point to remote points.
- Parameters:
rank (int)
- Return type:
- nfilled()
Number of filled entries
- size()
Size of the interaction indices array
- dtype()
Index datatype (so far always int32)
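The segment convention isrc[ispl[i]:ispl[i+1]] can be tried out on plain NumPy arrays. The arrays and the per-node quantity `value` below are invented for illustration; the sketch just sums a source quantity over each receiving node's segment.

```python
import numpy as np

# Hypothetical interaction list for 3 receiving nodes.
ispl = np.array([0, 2, 2, 5])          # node 1 has an empty segment
isrc = np.array([1, 2, 0, 1, 2])       # source node indices
value = np.array([10.0, 20.0, 30.0])   # some quantity stored per source node

# For every receiving node, sum the quantity over all of its source nodes.
total = np.array(
    [value[isrc[ispl[i] : ispl[i + 1]]].sum() for i in range(len(ispl) - 1)]
)
print(total)
```

A receiving node with an empty segment simply contributes an empty slice, so no special case is needed.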
- class jztree.data.Label(irank, igroup, ilocal_segment=None)
A FoF-group label for multi-GPU, pointing to a root particle’s rank and index
- Parameters:
irank (object)
igroup (object)
ilocal_segment (object | None)
- class jztree.data.FofCatalogue(ngroups, mass=None, count=None, offset=None, com_pos=None, com_vel=None, com_inertia_radius=None, scale_factor=None, v_rad=None, offset_rank=None)
Friends-of-friends catalogue data
Many properties are optional and will be set to None if the data needed to calculate them was missing. To allow jax.jit compatibility, the arrays are generally allocated larger than what is actually needed, with ngroups indicating the filled count. You may use squeeze_catalogue() (outside of jax.jit) to obtain a squeezed version.
- Parameters:
ngroups (object) – Actual number of groups. (On multi-GPU the local count.)
mass (object | None) – Group masses
count (object | None) – Number of particles
offset (object | None) – Starting point of each group in the particle array
com_pos (object | None) – Center of mass position
com_vel (object | None) – Center of mass velocity
com_inertia_radius (object | None) – Inertia radius \(\sqrt{\langle (\mathbf{x} - \mathbf{x}_0)^2 \rangle}\).
scale_factor (object | None) – For light-cone data: scale factor of light-cone crossing.
v_rad (object | None) – For light-cone data: radial (line of sight) velocity
offset_rank (object | None) – Provided for squeezed catalogues to indicate origin rank.
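To make the relation between offset, count and the per-group properties concrete, here is a small NumPy sketch that evaluates the mass, the centre of mass and the inertia radius for particles that are already in group order. It assumes that \(\mathbf{x}_0\) is the centre of mass and uses equal masses, so that mass-weighted and unweighted averages coincide; how the library weights the average is not stated here.

```python
import numpy as np

# Particles already in group order: group 0 has 2 members, group 1 has 3.
pos = np.array([[0.0, 0.0], [2.0, 0.0],
                [5.0, 5.0], [5.0, 7.0], [5.0, 9.0]])
mass = np.ones(len(pos))
offset = np.array([0, 2])
count = np.array([2, 3])

group_mass, com_pos, r_inertia = [], [], []
for off, cnt in zip(offset, count):
    x = pos[off : off + cnt]
    m = mass[off : off + cnt]
    com = (m[:, None] * x).sum(axis=0) / m.sum()
    # sqrt of the (mass-weighted) mean squared distance from the centre of mass
    r2 = (m * ((x - com) ** 2).sum(axis=1)).sum() / m.sum()
    group_mass.append(m.sum())
    com_pos.append(com)
    r_inertia.append(np.sqrt(r2))

group_mass = np.array(group_mass)
com_pos = np.array(com_pos)
r_inertia = np.array(r_inertia)
print(group_mass, com_pos, r_inertia)
```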
- class jztree.data.RankIdx(rank, idx)
Holds a rank and an index – used to point to a particle in multi-GPU setups.
- Parameters:
rank (object)
idx (object)
- jztree.data.expand_particles(part, ndev)
Expands particles from shape (Ndev*N) -> (Ndev,N)
Useful to interface with shard maps, since (Ndev,N) shape is assumed in general
- Parameters:
part (Pos)
ndev (int)
- jztree.data.flatten_particles(part)
Flattens particles from shape (Ndev,N) -> (Ndev*N).
- Parameters:
part (Pos)
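The shape conventions (Ndev*N) and (Ndev,N) correspond to a plain reshape of the leading axis. The NumPy sketch below shows this for a position array only; it assumes that each device owns a contiguous block of rows and does not model the per-device filled counts.

```python
import numpy as np

ndev, n, dim = 4, 3, 2
pos_flat = np.arange(ndev * n * dim, dtype=float).reshape(ndev * n, dim)

# (Ndev*N, dim) -> (Ndev, N, dim): device d owns the rows d*N : (d+1)*N.
pos_dev = pos_flat.reshape(ndev, n, dim)

# (Ndev, N, dim) -> (Ndev*N, dim): the inverse reshape.
pos_back = pos_dev.reshape(ndev * n, dim)

print(pos_dev.shape, pos_back.shape)
```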
- jztree.data.pad_particles(part, num, float_val=jnp.nan, int_val=0)
Pads particle data (to leave space for communication).
- Parameters:
part (Pos)
num (int)
float_val (float)
int_val (int)
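Padding can be pictured as extending every per-particle array to a larger allocation, with NaN for floating-point data and 0 for integer data, as in the defaults above. The sketch below does this with NumPy for two arrays; the variable `size` is hypothetical, and whether the `num` argument of pad_particles denotes the padded size or the number of added entries is not specified here.

```python
import numpy as np

pos = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
ids = np.array([7, 8, 9])
size = 5  # hypothetical target allocation

num_filled = len(pos)
npad = size - num_filled

# Float data is padded with NaN, integer data with 0.
pos_padded = np.pad(pos, ((0, npad), (0, 0)), constant_values=np.nan)
ids_padded = np.pad(ids, (0, npad), constant_values=0)

print(pos_padded)
print(ids_padded, num_filled)
```

The filled count has to be tracked separately, since the padded entries are otherwise indistinguishable from real integer data.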
- jztree.data.squeeze_catalogue(cata, size_out=None, offset_mode='rank', nparts=None)
Squeezes a multi-GPU FoF catalogue that was returned from a shard_map (Ndev,size_group) into a dense form (Ntot) and replicates it on every device
- Parameters:
cata (FofCatalogue) – FoF catalogue
size_out (int | None) – Can be provided to allow jit compatibility
offset_mode (str) – Can be “rank” or “flat”. Before squeezing, offsets indicate locations in the particle array of the same device. Since squeezing loses the device info, we need to either indicate the rank or convert offsets to global offsets that index a squeezed particle array (converts to int64). “global” needs “nparts” as an input, i.e. how many particles were on each device
nparts (object | None) – particles that are on each device, needed for “global” offset_mode
- Return type:
jztree.tree
This module contains functions for sorting particles into z-order, for building a plane-based tree hierarchy and for defining interaction lists.
- jztree.tree.zsort(pos)
Brings position vectors into z-order
If pos is a pytree, it needs to have a “pos” attribute, which will be used as the sorting key. All remaining leaves of the pytree with the same length will be sorted accordingly along the leading axis.
- Parameters:
pos (object | Pos) – Can be a jax.Array or a more complicated particle data structure following the jztree.data.Pos interface. All pytree-leaves with the same leading shape as the .pos attribute will be sorted.
- Returns:
(posz, idz) – Sorted positions and ids so that posz = pos[idz]
- Return type:
Tuple[object | Pos, object]
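Z-ordering can be illustrated with integer grid coordinates and an explicit Morton key obtained by interleaving bits. This is a textbook sketch, not the library's implementation: the helper `morton_key_2d`, the bit convention (x in the even bits) and the restriction to integer coordinates in 2D are all assumptions made for the example.

```python
import numpy as np

def morton_key_2d(ix, iy, bits=16):
    """Interleave the bits of two integer coordinates (x occupies the even bits)."""
    key = np.zeros(ix.shape, dtype=np.int64)
    for b in range(bits):
        key |= ((ix >> b) & 1) << (2 * b)
        key |= ((iy >> b) & 1) << (2 * b + 1)
    return key

pos = np.array([[1, 1], [0, 1], [1, 0], [0, 0], [2, 0]], dtype=np.int64)

key = morton_key_2d(pos[:, 0], pos[:, 1])
idz = np.argsort(key, kind="stable")  # permutation into z-order
posz = pos[idz]                       # sorted positions, posz = pos[idz]

print(key)
print(idz)
print(posz)
```

Any other per-particle array can be brought into the same order by indexing it with idz.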
- jztree.tree.search_sorted_z(posz, posz_query, block_size=64, leaf_search=False)
Finds the indices in posz where elements of posz_query would be inserted to keep order.
This is similar to np.searchsorted, but works for points sorted in Z-order. On equality maintains the rule: posz[idx] < v <= posz[idx+1]. If leaf_search is True, it is assumed that posz contains one point per leaf and we return the index of the leaf that the query point belongs to.
- Parameters:
posz (object)
posz_query (object)
block_size (int)
- jztree.tree.zsort_and_tree(part, cfg_tree=TreeConfig(), data=None, ptype=None, num_types=None, shrink=True)
Brings particles into z-order and creates a tree hierarchy.
In the multi-GPU scenario, particles may be communicated and their balance adjusted to ensure that top-level nodes don’t cross domain boundaries.
For multi-type trees, please use zsort_and_tree_multi_type().
- Parameters:
part (Pos) – position array or dataclass
cfg_tree (TreeConfig) – low-level configuration options
data (Any | None) – optional: additional particle data that will be carried along through sorting and communication. Can be a pytree and will be provided as an additional output
ptype (object | None) – optional, a type-index per particle to distinguish particle types in tree
num_types (int | None) – number of particle types
shrink (bool) – if False, always return three outputs
- Return type:
Tuple[Pos, TreeHierarchy]
- jztree.tree.zsort_and_tree_multi_type(part, cfg_tree=TreeConfig(), data=None)
Builds a tree with multiple particle types
The tree’s leaf-to-particle splits jztree.data.TreeHierarchy.ispl_l2p_per_type will be defined separately for each type.
- Parameters:
part (Tuple[Pos, ...]) – tuple of particle arrays of different types
cfg_tree (TreeConfig) – config object controlling tree structure and allocation
data (Any | None) – additional particle data to be carried along
- Returns:
(partz, th) – sorted particles and tree hierarchy. If data is provided, additionally returns the sorted data.
- Return type:
Tuple[Tuple[Pos, …], TreeHierarchy]
- jztree.tree.build_tree_hierarchy(partz, cfg_tree, lvl_bound=None, ptype=None, num_types=None)
Builds a tree hierarchy from z-order positions
The zeroth level of the tree corresponds to leaves, which contain multiple particles. Nodes (and leaves) are selected so that they are as big as possible while not containing more than a maximum number of particles that starts at cfg_tree.max_leaf_size and increases per level by a factor cfg_tree.coarse_fac.
Nodes are parameterized through a set of splits. For example, the particles that lie in the leaf with index i are given by part[ispl_n2n.get(level=0)[i]: ispl_n2n.get(level=0)[i+1]]. The ith node of level n contains all level n-1 nodes in the range ispl_n2n.get(n)[i]: ispl_n2n.get(n)[i+1].
In jax, memory size needs to be known at compile time, but the required number of nodes is data dependent on each level. To limit the number of allocations that we need to predict, we use the PackedArray class, which helps us stack multiple different levels into a single contiguous array while accessing them “almost” as if they were separate arrays.
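The per-level particle cap described above grows geometrically. As a quick arithmetic check with the TreeConfig defaults (max_leaf_size=32, coarse_fac=6.0), and assuming the cap is simply the product without any rounding:

```python
max_leaf_size = 32   # TreeConfig default
coarse_fac = 6.0     # TreeConfig default

# Maximum number of particles allowed in a node on each of the first four levels.
caps = [max_leaf_size * coarse_fac**level for level in range(4)]
print(caps)
```

With these defaults the cap passes one thousand particles at the third level, which gives a feeling for how few levels a tree needs even for large particle counts.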
It is recommended to use the zsort_and_tree() interface, especially for multi-GPU, since the particle distribution needs to be adjusted to accommodate nodes.
- Parameters:
partz (PosMass | object)
cfg_tree (TreeConfig)
lvl_bound (object | None)
ptype (object | None)
num_types (object | None)
- Return type:
- jztree.tree.grouped_dense_interaction_list(nnodes, size_ilist, ngroup=32, size_super=None, node_range=None, dtype=jnp.int32)
Defines an all-to-all interaction list over super-nodes and a super-node to node relation
This is useful for evaluating all-to-all interactions in a grouped manner on GPU
- Parameters:
nnodes (object | int) – number of top-level nodes
size_ilist (int) – allocation size of the interaction list
ngroup (int) – how many top-level nodes should be summarized into one super node
size_super (int | None) – size of the super-node allocation
node_range (object | None) – can be provided to only include receiving nodes in this range (but source nodes on the full domain). Useful for multi-GPU scenarios.
dtype – data type of the interaction list. So far only int32 is supported.
- Returns:
A tuple (spl_super, ilist, nsuper_nodes) containing the super node splitting points, the interaction list and the number of super nodes.
- Return type:
Tuple[object, InteractionList, object]
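One way to picture the returned quantities is the following NumPy sketch, which groups consecutive top-level nodes into super nodes and lets every super node interact with every super node. It is an interpretation of the description above rather than the library's code; in particular, how the real function lays out and pads its fixed-size outputs is not covered.

```python
import numpy as np

nnodes, ngroup = 10, 4

# Number of super nodes: ceil(nnodes / ngroup).
nsuper = -(-nnodes // ngroup)

# Super node s contains the nodes spl_super[s] : spl_super[s+1].
spl_super = np.minimum(np.arange(nsuper + 1) * ngroup, nnodes)

# All-to-all list over super nodes, in the ispl/isrc convention:
# receiving super node i interacts with isrc[ispl[i] : ispl[i+1]].
ispl = np.arange(nsuper + 1) * nsuper
isrc = np.tile(np.arange(nsuper), nsuper)

print(nsuper)
print(spl_super)
print(ispl)
print(isrc)
```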
- jztree.tree.simplify_interaction_list(ilist, always_keep=None)
Reduces an interaction list to skip nodes that don’t appear as receiving or source indices
Useful in multi-GPU scenarios where many non-local nodes will not have any local interactions
- Parameters:
ilist (InteractionList)
always_keep (object | None)
- Return type:
jztree.knn
This module contains functions for doing k-nearest neighbour search.
- jztree.knn.knn(part, k, boxsize=None, th=None, part_query=None, result='rad_globalidx', reduce_func=None, output_order='input', cfg=KNNConfig())
The main function executing the nearest neighbour search.
By default this returns (rnn, inn) – the radii and indices of the k nearest neighbours. Arguments may modify returned results and their ordering.
- Parameters:
part (object | Pos) – Source positions, may be a jax.Array or a more complex dataclass following the jztree.data.Pos interface.
k (int) – Number of neighbours to calculate per particle
boxsize (float | None) – If provided, distance calculations are wrapped periodically.
th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.
part_query (object | Pos | None) – Query positions, defaults to part.
result (str | object | Any) – String indicating the desired return values. May have any of the following separated by underscores: “rad” (radii), “drad” (differentiable radii), “rankidx” (origin rank and index jztree.data.RankIdx), “globalidx” (linear global index), “part” (mapped source particles), “reduce” (see reduce_func) or the name of any attribute on the particle data structure.
reduce_func (Callable | None) – Will be called with the neighbour list in z-order to get a summary statistic per query particle. Useful to avoid communicating the neighbour list to origin tasks in distributed setups. Requires “reduce” inside of result.
output_order (str) – May be “input” or “z”.
cfg (KNNConfig) – Config object that controls lower-level details of the algorithm
- Returns:
tuple (rnn, inn) with neighbour radii and indices
- Return type:
Tuple[object, object]
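For small inputs, the meaning of (rnn, inn) and of periodic wrapping can be checked against a brute-force reference. The NumPy function below is such a reference, not the tree-based algorithm of the library; the name `knn_brute_force` is hypothetical, and in this sketch every point counts as its own nearest neighbour at distance zero, which may or may not match the library's convention.

```python
import numpy as np

def knn_brute_force(pos, k, boxsize=None):
    """O(N^2) reference: radii and indices of the k nearest neighbours."""
    dx = pos[:, None, :] - pos[None, :, :]
    if boxsize is not None:
        # Minimum-image convention for a periodic box.
        dx -= boxsize * np.round(dx / boxsize)
    r = np.sqrt((dx**2).sum(axis=-1))
    inn = np.argsort(r, axis=1, kind="stable")[:, :k]
    rnn = np.take_along_axis(r, inn, axis=1)
    return rnn, inn

pos = np.array([[0.1, 0.5], [0.4, 0.5], [0.95, 0.5]])

rnn, inn = knn_brute_force(pos, k=2, boxsize=1.0)   # periodic
rnn_open, inn_open = knn_brute_force(pos, k=2)      # non-periodic

print(inn, rnn)
print(inn_open, rnn_open)
```

In the periodic case the first point finds the third point across the box boundary, whereas without wrapping its closest neighbour is the second point.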
jztree.fof
This module contains functions to calculate friends-of-friends (FoF) group labels, for bringing particles into group order and for calculating a group catalogue.
- jztree.fof.fof_is_superset(igroup_sup, igroup, mask=None)
Checks whether every FoF group in igroup_sup is a superset of sets in igroup
Useful for comparing different FoF implementations where labels may differ even when they represent the same groups.
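The superset relation between two labelings can be phrased as: every group of the finer labeling lies inside a single group of the coarser one. The NumPy sketch below checks exactly that; the function name is hypothetical, the mask argument is not modelled, and the library's own check may be defined differently in detail.

```python
import numpy as np

def is_superset_partition(igroup_sup, igroup):
    """True if every group of `igroup` lies inside a single group of `igroup_sup`."""
    for g in np.unique(igroup):
        if len(np.unique(igroup_sup[igroup == g])) != 1:
            return False
    return True

igroup_sup = np.array([0, 0, 0, 3, 3])  # groups {0, 1, 2} and {3, 4}
igroup = np.array([0, 0, 2, 3, 3])      # groups {0, 1}, {2} and {3, 4}

print(is_superset_partition(igroup_sup, igroup))
print(is_superset_partition(igroup, igroup_sup))
```

Because only set membership is compared, the check does not depend on which particle index each implementation uses as the group label.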
- jztree.fof.fof_labels(part, rlink, boxsize=0., cfg=FofConfig(), th=None, output_order='z')
Calculates the friends-of-friends group relationship for single-GPU cases
For multi-GPU, please use distr_fof_labels().
- Parameters:
part (object | Pos) – particles, can be a jax.Array or any particle data structure with a .pos attribute
rlink (float) – (absolute) linking length
boxsize (float) – if provided, distances use periodic wrapping
cfg (FofConfig) – controls low-level details
th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.
output_order (str) – may be “z” or “input”
- Returns:
(part, igroup) – (possibly) reordered particles and group labels – igroup points to the first particle in z-order that belongs to the same group.
- Return type:
Tuple[object | Pos, object]
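The meaning of the group labels can be checked on small inputs with a brute-force union-find, where every particle is labelled by the smallest particle index in its group. This is a reference sketch, not the library's algorithm: the function name is hypothetical, periodic wrapping is omitted, the label refers to the input order rather than z-order, and linking pairs at exactly the linking length (<=) is an assumption.

```python
import numpy as np

def fof_labels_brute_force(pos, rlink):
    """O(N^2) reference: label = smallest particle index within the same group."""
    n = len(pos)
    parent = np.arange(n)

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            if np.sum((pos[i] - pos[j]) ** 2) <= rlink**2:
                ri, rj = find(i), find(j)
                # Keep the smaller index as the root of the merged group.
                if ri < rj:
                    parent[rj] = ri
                else:
                    parent[ri] = rj
    return np.array([find(i) for i in range(n)])

pos = np.array([[0.0], [0.4], [3.0], [0.8], [3.3]])
igroup = fof_labels_brute_force(pos, rlink=0.5)
print(igroup)
```

Particles 0 and 3 are farther apart than the linking length but end up in the same group through particle 1, which is the defining property of friends-of-friends.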
- jztree.fof.distr_fof_labels(part, rlink, boxsize=0., cfg=FofConfig(), th=None, linearize_labels=False)
Calculates the friends-of-friends group relationship for multi-GPU cases.
For single-GPU, please use fof_labels().
- Parameters:
part (Pos) – particles, should have some padding and follow the jztree.data.Pos interface
rlink (float) – (absolute) linking length
boxsize (float) – if provided, distances use periodic wrapping
cfg (FofConfig) – controls low-level details
th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.
linearize_labels (bool) – if true, group labels will be output as a scalar global index. Only use this for comparing small setups against single-GPU.
- Returns:
(partz, label) – z-ordered particles and group labels – label points to the rank and index of the first particle in z-order that belongs to the same group.
- Return type:
- jztree.fof.fof_and_catalogue(part, rlink, boxsize=0., cfg=FofConfig(), th=None)
Puts particles in FoF-order and calculates the FoF-catalogue
Supports single- and multi-GPU.
- Parameters:
part (ParticleData) – particles, should follow at least the interface of jztree.data.Pos and may optionally contain additional elements as in jztree.data.ParticleData for calculating additional data in the catalogue. For multi-GPU, particles need to be padded to allow space for communication.
rlink (float) – (absolute) linking length
boxsize (float) – if provided, distances use periodic wrapping
cfg (FofConfig) – controls low-level details
th (TreeHierarchy | None) – May be provided to skip building a new tree-hierarchy inside of this function. See jztree.tree.zsort_and_tree(). If this argument is provided, particles are assumed to be already in z-order.
- Returns:
(partf, catalogue) – particles in FoF-order and the group catalogue. Each group forms a contiguous segment in the particle array, but the last group on each rank may continue on the next rank.
- Return type:
Tuple[ParticleData, FofCatalogue]
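How group labels translate into a group-ordered particle array with offsets and counts can be sketched with NumPy on a single device. The labels below are invented, and the ordering of groups (here by label) as well as the treatment of groups below npart_min are assumptions; the library may order and filter differently.

```python
import numpy as np

igroup = np.array([0, 0, 2, 0, 2, 5])  # hypothetical group label per particle
npart_min = 2

# Stable sort by label brings the members of each group together.
order = np.argsort(igroup, kind="stable")

# Start index and size of each contiguous group segment in the sorted array.
labels, offset, count = np.unique(
    igroup[order], return_index=True, return_counts=True
)

# Keep only groups with at least npart_min members in the catalogue.
keep = count >= npart_min
print(order)
print(offset[keep], count[keep])
```

Group g of the catalogue then corresponds to the particles order[offset[g] : offset[g] + count[g]] of the original array.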