scooby.modeling

Package Contents

class scooby.modeling.Scooby(config, cell_emb_dim, embedding_dim=1920, n_tracks=2, disable_cache=False, use_transform_borzoi_emb=False, cachesize=2, **params)

Bases: borzoi_pytorch.Borzoi

cell_emb_dim
cachesize = 2
use_transform_borzoi_emb = False
n_tracks = 2
embedding_dim = 1920
disable_cache = False
cell_state_to_conv

_init_weights(module)

Initializes the weights of the given module.

forward_cell_embs_only(cell_emb)

Processes cell embeddings to generate convolutional filter weights and biases.

Args:

cell_emb: Tensor of cell embeddings (batch_size, num_cells, cell_emb_dim).

Returns:

Tuple: Convolutional filter weights and biases.
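A minimal NumPy sketch of the idea, written without PyTorch so it stays self-contained. It assumes, hypothetically, that cell_state_to_conv acts as a single linear projection whose output per cell holds embedding_dim * n_tracks filter weights followed by n_tracks biases. The helper cell_embs_to_conv_params and the random projection are illustrative only, not scooby's API.

```python
import numpy as np


def cell_embs_to_conv_params(cell_emb, proj_w, proj_b, embedding_dim, n_tracks):
    """Hypothetical stand-in for forward_cell_embs_only.

    cell_emb: (batch_size, num_cells, cell_emb_dim)
    proj_w:   (cell_emb_dim, embedding_dim * n_tracks + n_tracks)
    proj_b:   (embedding_dim * n_tracks + n_tracks,)
    """
    # One linear projection per cell: (batch, cells, embedding_dim * n_tracks + n_tracks)
    flat = cell_emb @ proj_w + proj_b
    batch_size, num_cells, _ = flat.shape
    n_weights = embedding_dim * n_tracks
    # Assumed layout: first block -> one filter per (cell, track)
    weights = flat[..., :n_weights].reshape(batch_size, num_cells, n_tracks, embedding_dim)
    # Assumed layout: last n_tracks entries -> one bias per (cell, track)
    biases = flat[..., n_weights:]
    return weights, biases


rng = np.random.default_rng(0)
cell_emb_dim, embedding_dim, n_tracks = 14, 8, 2
proj_w = rng.normal(size=(cell_emb_dim, embedding_dim * n_tracks + n_tracks))
proj_b = np.zeros(embedding_dim * n_tracks + n_tracks)
cell_emb = rng.normal(size=(1, 3, cell_emb_dim))  # 1 sequence, 3 cells
weights, biases = cell_embs_to_conv_params(cell_emb, proj_w, proj_b, embedding_dim, n_tracks)
print(weights.shape, biases.shape)  # (1, 3, 2, 8) (1, 3, 2)
```

Under this assumed layout and the documented defaults (embedding_dim=1920, n_tracks=2), the projection would emit 1920 * 2 + 2 = 3842 values per cell.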

forward_seq_to_emb(sequence)

Processes DNA sequences through the Borzoi backbone to obtain sequence embeddings.

Args:

sequence: Tensor of DNA sequences (batch_size, seq_len, 4).

Returns:

Tensor: Sequence embeddings.
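The (batch_size, seq_len, 4) layout points to a one-hot encoding of nucleotides. A minimal sketch of building such a batch, assuming A, C, G, T channel order and all-zero rows for ambiguous bases such as N; both conventions, and the helper one_hot_batch, are assumptions to check against the actual data pipeline.

```python
import numpy as np

# Assumed channel order; verify against the data pipeline feeding the model.
ALPHABET = "ACGT"


def one_hot_batch(seqs):
    """Encode equal-length DNA strings as a (batch_size, seq_len, 4) float array.

    Characters outside ACGT (e.g. N) are left as all-zero rows.
    """
    seq_len = len(seqs[0])
    if any(len(s) != seq_len for s in seqs):
        raise ValueError("all sequences must have the same length")
    out = np.zeros((len(seqs), seq_len, len(ALPHABET)), dtype=np.float32)
    lookup = {base: i for i, base in enumerate(ALPHABET)}
    for b, seq in enumerate(seqs):
        for pos, base in enumerate(seq.upper()):
            idx = lookup.get(base)
            if idx is not None:
                out[b, pos, idx] = 1.0
    return out


batch = one_hot_batch(["ACGTN", "ttgca"])
print(batch.shape)  # (2, 5, 4)
print(batch[0, 4])  # [0. 0. 0. 0.]
```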

forward_convs_on_emb(seq_emb, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict=None)

Applies cell-state-specific convolutions to sequence embeddings.

Args:

seq_emb: Tensor of sequence embeddings.
cell_emb_conv_weights: Convolutional filter weights.
cell_emb_conv_biases: Convolutional filter biases.
bins_to_predict (optional): Indices of bins to predict (if None, predicts all bins).

Returns:

Tensor: Predicted profiles.
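One way to read "cell-state-specific convolutions" is as a kernel-size-1 (pointwise) convolution over the embedding channels, with a separate filter and bias for every cell and track. The NumPy sketch below assumes that interpretation, a channels-first embedding of shape (batch_size, embedding_dim, n_bins), and the hypothetical weight/bias layout noted in the docstring; it omits any output activation the real model may apply. pointwise_cell_convs is an illustrative name, not part of scooby.

```python
import numpy as np


def pointwise_cell_convs(seq_emb, weights, biases, bins_to_predict=None):
    """Hypothetical pointwise version of forward_convs_on_emb.

    seq_emb: (batch_size, embedding_dim, n_bins), assumed channels-first
    weights: (batch_size, num_cells, n_tracks, embedding_dim)
    biases:  (batch_size, num_cells, n_tracks)
    Returns: (batch_size, num_cells, n_selected_bins, n_tracks)
    """
    if bins_to_predict is not None:
        # Restrict the computation to the requested bins only
        seq_emb = seq_emb[:, :, bins_to_predict]
    # Dot product over embedding channels for every (cell, track, bin)
    out = np.einsum("bctk,bkl->bclt", weights, seq_emb)
    # Broadcast one bias per (cell, track) across all bins
    return out + biases[:, :, None, :]


rng = np.random.default_rng(0)
seq_emb = rng.normal(size=(1, 8, 16))    # 8 channels, 16 bins
weights = rng.normal(size=(1, 3, 2, 8))  # 3 cells, 2 tracks
biases = rng.normal(size=(1, 3, 2))
full = pointwise_cell_convs(seq_emb, weights, biases)
subset = pointwise_cell_convs(seq_emb, weights, biases, bins_to_predict=[0, 5, 7])
print(full.shape, subset.shape)  # (1, 3, 16, 2) (1, 3, 3, 2)
```

Selecting bins before the contraction makes the cost scale with the number of requested bins rather than the full output length, which is presumably the point of bins_to_predict.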

forward_sequence_w_convs(sequence, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict=None)

Processes DNA sequences, applies cell-state-specific convolutions, and caches the results.

Args:

sequence: Tensor of DNA sequences.
cell_emb_conv_weights: Convolutional filter weights.
cell_emb_conv_biases: Convolutional filter biases.
bins_to_predict (optional): Indices of bins to predict.

Returns:

Tensor: Predicted profiles.
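Caching of this kind can be pictured as a small least-recently-used store keyed by the input sequence, so that scoring many cells against the same sequence runs the expensive backbone only once. The sketch below is hypothetical: the SeqEmbeddingCache class, keying on the raw bytes of the array, and the capacity/disabled arguments (loosely mirroring cachesize and disable_cache) are illustrative assumptions, not scooby's implementation.

```python
from collections import OrderedDict

import numpy as np


class SeqEmbeddingCache:
    """Hypothetical LRU cache mapping an input sequence to its embedding."""

    def __init__(self, embed_fn, capacity=2, disabled=False):
        self.embed_fn = embed_fn  # expensive sequence -> embedding function
        self.capacity = capacity  # loosely analogous to `cachesize`
        self.disabled = disabled  # loosely analogous to `disable_cache`
        self._store = OrderedDict()
        self.misses = 0  # number of times embed_fn actually ran

    def __call__(self, sequence):
        if self.disabled or self.capacity == 0:
            self.misses += 1
            return self.embed_fn(sequence)
        # Key on shape, dtype, and raw bytes so equal arrays share an entry
        key = (sequence.shape, sequence.dtype.str, sequence.tobytes())
        if key in self._store:
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        self.misses += 1
        emb = self.embed_fn(sequence)
        self._store[key] = emb
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)  # evict least recently used
        return emb


def toy_backbone(seq):
    # Stand-in for the Borzoi trunk: sums the one-hot channels per position
    return seq.sum(axis=-1)


cache = SeqEmbeddingCache(toy_backbone, capacity=2)
a = np.eye(4)[[0, 1, 2, 3]][None]  # (1, 4, 4) one-hot "ACGT"
b = np.eye(4)[[3, 2, 1, 0]][None]  # (1, 4, 4) one-hot "TGCA"
for seq in (a, a, b, a):
    cache(seq)
print(cache.misses)  # 2
```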

forward(sequence, cell_emb)

Forward pass of the scooby model.

Args:

sequence: Tensor of DNA sequences (batch_size, seq_len, 4).
cell_emb: Tensor of cell embeddings (batch_size, num_cells, cell_emb_dim).

Returns:

Tensor: Predicted profiles for each cell (batch_size, num_cells, seq_len, n_tracks).
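Read alongside the methods documented above, forward plausibly chains forward_cell_embs_only, forward_seq_to_emb, and the cell-specific convolutions. The toy below wires hypothetical stand-ins for each stage (a binned linear map in place of the Borzoi trunk, a random projection in place of cell_state_to_conv, and a pointwise contraction) purely to trace tensor shapes. BIN_SIZE, backbone_w, proj_w, and toy_forward are invented for illustration, and the third output axis in this toy counts bins because of the stand-in backbone.

```python
import numpy as np

rng = np.random.default_rng(0)
BIN_SIZE = 4  # hypothetical bin width in bp; seq_len must be divisible by it
cell_emb_dim, embedding_dim, n_tracks = 6, 8, 2

# Hypothetical parameters standing in for the backbone and cell_state_to_conv
backbone_w = rng.normal(size=(4, embedding_dim))
proj_w = rng.normal(size=(cell_emb_dim, embedding_dim * n_tracks + n_tracks))


def toy_forward(sequence, cell_emb):
    """sequence: (batch, seq_len, 4); cell_emb: (batch, num_cells, cell_emb_dim)."""
    batch, seq_len, _ = sequence.shape
    # 1) Sequence -> channels-first embedding (batch, embedding_dim, n_bins)
    binned = sequence.reshape(batch, seq_len // BIN_SIZE, BIN_SIZE, 4).mean(axis=2)
    seq_emb = (binned @ backbone_w).transpose(0, 2, 1)
    # 2) Cell embeddings -> per-cell filter weights and biases
    flat = cell_emb @ proj_w
    n_w = embedding_dim * n_tracks
    weights = flat[..., :n_w].reshape(batch, -1, n_tracks, embedding_dim)
    biases = flat[..., n_w:]
    # 3) Pointwise cell-specific convolution over the embedding channels
    out = np.einsum("bctk,bkl->bclt", weights, seq_emb)
    return out + biases[:, :, None, :]


sequence = np.eye(4)[rng.integers(0, 4, size=(2, 40))]  # (2, 40, 4) one-hot
cell_emb = rng.normal(size=(2, 5, cell_emb_dim))        # 5 cells per sequence
pred = toy_forward(sequence, cell_emb)
print(pred.shape)  # (2, 5, 10, 2)
```

A useful property of this design is that each cell's profile depends only on its own embedding, so cells can be added or dropped without recomputing the sequence embedding.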