Utils module

MARBLE.utils.print_settings(model)

Print the model's parameters to the screen.

MARBLE.utils.parallel_proc(fun, iterable, inputs, processes=-1, desc='')

Distribute the evaluation of fun over the elements of iterable across several processes.

MARBLE.utils.move_to_gpu(model, data, adjs=None)

Move the model and data (and adjs, if given) to the GPU.

MARBLE.utils.detach_from_gpu(model, data, adjs=None)

Detach the model and data (and adjs, if given) from the GPU.

MARBLE.utils.to_SparseTensor(edge_index, size=None, value=None)

Build an adjacency matrix as a torch_sparse SparseTensor.

Parameters:
  • edge_index (2xE matrix) – edge indices, one column per edge.

  • size – pair (rows, cols) giving the size of the matrix. If omitted, the size is inferred from the largest node index in edge_index.

  • value – list of edge weights. If omitted, every edge has weight 1.

Returns:

adjacency matrix in SparseTensor format
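
The documented defaults can be illustrated with a dense NumPy analogue. This is a hypothetical sketch, not MARBLE's code: the helper name to_dense_adjacency is invented, and taking N as the largest node index plus one is an assumption. With torch_sparse installed, the same inputs would map onto its constructor as SparseTensor(row=..., col=..., value=..., sparse_sizes=size).

```python
import numpy as np


def to_dense_adjacency(edge_index, size=None, value=None):
    """Hypothetical dense analogue of MARBLE.utils.to_SparseTensor."""
    edge_index = np.asarray(edge_index)
    row, col = edge_index[0], edge_index[1]
    if value is None:
        # Documented default: every edge gets unit weight.
        value = np.ones(edge_index.shape[1])
    if size is None:
        # Assumption: square N x N matrix with N = largest node index + 1.
        n = int(edge_index.max()) + 1
        size = (n, n)
    adj = np.zeros(size)
    # Assumes no duplicate edges; a repeated (row, col) keeps the last weight.
    adj[row, col] = value
    return adj


adj = to_dense_adjacency([[0, 1, 2], [1, 2, 0]])  # directed 3-cycle
print(adj.shape)  # (3, 3)
```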

MARBLE.utils.np2torch(x, dtype=None)

Convert a NumPy array to a torch tensor.

MARBLE.utils.to_list(x)

Convert x to a list.

MARBLE.utils.to_pandas(x, augment_time=True)

Convert NumPy data to pandas.

class MARBLE.utils.EdgeIndex(edge_index: Tensor, e_id: Tensor | None, size: Tuple[int, int])

Named tuple bundling an edge index, its optional edge IDs (e_id) and the matrix size.

Create new instance of EdgeIndex(edge_index, e_id, size)

edge_index: Tensor

Alias for field number 0

e_id: Tensor | None

Alias for field number 1

size: Tuple[int, int]

Alias for field number 2

to(*args, **kwargs)
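
Because EdgeIndex is a named tuple, each field is reachable by name or by position. The sketch below mirrors its layout with typing.NamedTuple; the class name is hypothetical, and the body of to() is an assumption modelled on the similar EdgeIndex tuple in PyTorch Geometric's NeighborSampler, not MARBLE's confirmed code.

```python
from typing import Any, NamedTuple, Optional, Tuple


class EdgeIndexSketch(NamedTuple):
    """Hypothetical stand-in mirroring the fields of MARBLE.utils.EdgeIndex."""

    edge_index: Any          # a 2 x E tensor in the real class
    e_id: Optional[Any]      # optional edge IDs, may be None
    size: Tuple[int, int]    # (rows, cols)

    def to(self, *args, **kwargs):
        # Assumption: forward the arguments to each tensor's own .to()
        # and rebuild the tuple; size holds plain ints and is reused as-is.
        edge_index = self.edge_index.to(*args, **kwargs)
        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
        return EdgeIndexSketch(edge_index, e_id, self.size)
```

Since it is a tuple, ei.edge_index and ei[0] refer to the same object, which is what the "Alias for field number" entries describe.
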

MARBLE.utils.expand_index(ind, dim)

Interleave dim incremented copies of ind
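
One reading of this docstring, sketched in plain Python: copy k of ind is ind*dim + k (k = 0, ..., dim-1), and the copies are interleaved so that each node index expands to a contiguous block of dim indices, in line with the node-to-vector-space expansion described for expand_edge_index. The helper name and the scaling by dim are assumptions.

```python
def expand_index_sketch(ind, dim):
    """Hypothetical: expand each index i into i*dim, i*dim + 1, ..., i*dim + dim - 1."""
    # dim incremented copies of the (scaled) index list
    copies = [[i * dim + k for i in ind] for k in range(dim)]
    # interleave: take element j from every copy before moving to j + 1
    return [copy[j] for j in range(len(ind)) for copy in copies]


print(expand_index_sketch([0, 2], 3))  # [0, 1, 2, 6, 7, 8]
```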

MARBLE.utils.to_block_diag(sp_tensors)

Assemble a list of sparse tensors into one block-diagonal sparse tensor.
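
The idea can be sketched on sparse matrices held as COO lists (rows, cols, values, shape): each block's indices are shifted by the total number of rows and columns of the blocks placed before it. This is an illustrative stand-in with a hypothetical helper name, not MARBLE's implementation; for SciPy matrices, scipy.sparse.block_diag performs the same assembly.

```python
def block_diag_coo(blocks):
    """Hypothetical: place COO blocks (rows, cols, vals, (n_rows, n_cols)) on the diagonal."""
    rows, cols, vals = [], [], []
    row_off = col_off = 0
    for r, c, v, (n_rows, n_cols) in blocks:
        # shift this block past every block already placed
        rows += [i + row_off for i in r]
        cols += [j + col_off for j in c]
        vals += list(v)
        row_off += n_rows
        col_off += n_cols
    return rows, cols, vals, (row_off, col_off)


A = ([0], [1], [5.0], (2, 2))  # 2 x 2 block with A[0, 1] = 5
B = ([0], [0], [7.0], (1, 1))  # 1 x 1 block with B[0, 0] = 7
print(block_diag_coo([A, B]))  # ([0, 2], [1, 2], [5.0, 7.0], (3, 3))
```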

MARBLE.utils.expand_edge_index(edge_index, dim=1)

When using rotations, each node is replaced by a dim-dimensional vector space, so the adjacency matrix must be expanded from n x n to n*dim x n*dim.
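
A minimal sketch of this expansion on a 2 x E edge list, assuming each edge (i, j) becomes the full dim x dim block of edges (i*dim + a, j*dim + b) for a, b in range(dim). That is the sparsity pattern needed to store one dim x dim matrix (such as a rotation) per edge. The function name is hypothetical.

```python
def expand_edge_index_sketch(edge_index, dim=1):
    """Hypothetical: replace each edge (i, j) by a dim x dim block of edges."""
    if dim == 1:
        return [list(edge_index[0]), list(edge_index[1])]
    src, dst = [], []
    for i, j in zip(edge_index[0], edge_index[1]):
        # node i owns rows i*dim .. i*dim + dim - 1, node j the matching columns
        for a in range(dim):
            for b in range(dim):
                src.append(i * dim + a)
                dst.append(j * dim + b)
    return [src, dst]


print(expand_edge_index_sketch([[0], [1]], dim=2))  # [[0, 0, 1, 1], [2, 3, 2, 3]]
```

Under this assumption, an n x n pattern with E edges becomes an n*dim x n*dim pattern with E*dim*dim edges.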

MARBLE.utils.tile_tensor(tensor, dim)

Enlarge an n x n tensor to an n*dim x n*dim block matrix. This is effectively a sparse version of torch.kron(tensor, torch.ones((dim, dim))).
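
The kron identity can be checked on a small example. The sketch tiles a matrix stored as COO entries, turning each nonzero K[i, j] into a dim x dim block filled with K[i, j], and compares the result with np.kron. Using NumPy in place of torch and the helper name tile_coo are assumptions made for illustration.

```python
import numpy as np


def tile_coo(rows, cols, vals, dim):
    """Hypothetical sparse tiling: K[i, j] -> dim x dim block filled with K[i, j]."""
    out_r, out_c, out_v = [], [], []
    for i, j, v in zip(rows, cols, vals):
        for a in range(dim):
            for b in range(dim):
                out_r.append(i * dim + a)
                out_c.append(j * dim + b)
                out_v.append(v)
    return out_r, out_c, out_v


K = np.array([[1.0, 2.0], [0.0, 3.0]])
r, c = np.nonzero(K)  # only nonzeros are tiled, as in a sparse tensor
tr, tc, tv = tile_coo(r, c, K[r, c], dim=2)
T = np.zeros((4, 4))
T[tr, tc] = tv
print(np.array_equal(T, np.kron(K, np.ones((2, 2)))))  # True
```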

MARBLE.utils.restrict_dimension(sp_tensor, d, m)

Limit the dimension of the tensor

MARBLE.utils.restrict_to_batch(sp_tensor, idx)

Restrict the sparse tensor to the current batch, selected by idx.

MARBLE.utils.standardize(X)

Standardise data row-wise.
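
"Row-wise" can mean either that each row is normalised on its own or that statistics are computed across rows for each column. The z-score sketch below exposes a hypothetical axis parameter and defaults to axis=0 (per-column statistics computed across rows); which convention MARBLE.utils.standardize follows is an assumption, not confirmed here.

```python
import numpy as np


def standardize_sketch(X, axis=0):
    """Hypothetical z-score: subtract the mean and divide by the std along axis."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=axis, keepdims=True)
    # population standard deviation (ddof=0); constant slices give 0 and thus nan/inf
    std = X.std(axis=axis, keepdims=True)
    return (X - mean) / std


Z = standardize_sketch([[1.0, 10.0], [3.0, 30.0]])
print(Z.tolist())  # [[-1.0, -1.0], [1.0, 1.0]]
```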