scooby.modeling
Package Contents
- class scooby.modeling.Scooby(config, cell_emb_dim, embedding_dim=1920, n_tracks=2, disable_cache=False, use_transform_borzoi_emb=False, cachesize=2, **params)
Bases: borzoi_pytorch.Borzoi
- cell_emb_dim
- cachesize = 2
- use_transform_borzoi_emb = False
- n_tracks = 2
- embedding_dim = 1920
- disable_cache = False
- cell_state_to_conv
- _init_weights(module)
Initializes the weights of the given module.
- forward_cell_embs_only(cell_emb)
Processes cell embeddings to generate convolutional filter weights and biases.
- Args:
cell_emb: Tensor of cell embeddings (batch_size, num_cells, cell_emb_dim).
- Returns:
Tuple: Convolutional filter weights and biases.
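The mapping from cell embeddings to per-cell convolution parameters can be pictured with a small pure-Python sketch. It assumes, without confirmation from this reference, that cell_state_to_conv behaves like a linear layer whose output for each cell is split into n_tracks * embedding_dim filter weights followed by n_tracks biases. The function cell_emb_to_conv_params and the proj matrix are hypothetical and are not part of the scooby API; the real method operates on batched tensors.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def cell_emb_to_conv_params(
    cell_emb: List[float],
    proj: Matrix,
    embedding_dim: int,
    n_tracks: int,
) -> Tuple[Matrix, List[float]]:
    """Map one cell embedding to conv weights (n_tracks, embedding_dim) and biases (n_tracks,).

    Hypothetical sketch: `proj` stands in for a linear layer with
    n_tracks * embedding_dim + n_tracks output rows, each of length len(cell_emb).
    """
    n_out = n_tracks * embedding_dim + n_tracks
    if len(proj) != n_out:
        raise ValueError(f"proj must have {n_out} rows, got {len(proj)}")
    # Linear projection: one dot product per output unit.
    flat = [sum(w * x for w, x in zip(row, cell_emb)) for row in proj]
    # First n_tracks * embedding_dim values are reshaped into the filter weights...
    weights = [flat[t * embedding_dim:(t + 1) * embedding_dim] for t in range(n_tracks)]
    # ...and the remaining n_tracks values are the per-track biases.
    biases = flat[n_tracks * embedding_dim:]
    return weights, biases
```

Splitting a single projection into weights and biases is one common way to build a small hypernetwork; whether scooby does exactly this should be checked against its source.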
- forward_seq_to_emb(sequence)
Processes DNA sequences through the Borzoi backbone to obtain sequence embeddings.
- Args:
sequence: Tensor of DNA sequences (batch_size, seq_len, 4).
- Returns:
Tensor: Sequence embeddings.
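The sequence argument is documented as a one-hot tensor of shape (batch_size, seq_len, 4). The sketch below shows how a batch of DNA strings could be brought into that layout using nested lists. The A, C, G, T column order and the all-zero encoding of unknown bases such as N are assumptions of this sketch, and the helper names are hypothetical; check them against the data pipeline actually used with the model.

```python
from typing import List

# Assumed column order; verify against the preprocessing used with the model.
BASE_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def one_hot_encode(seq: str) -> List[List[float]]:
    """Encode a DNA string as a (seq_len, 4) nested list.

    Unknown bases such as 'N' become all-zero rows (an assumption of this sketch).
    """
    encoded = []
    for base in seq.upper():
        row = [0.0, 0.0, 0.0, 0.0]
        idx = BASE_INDEX.get(base)
        if idx is not None:
            row[idx] = 1.0
        encoded.append(row)
    return encoded


def batch_one_hot(seqs: List[str]) -> List[List[List[float]]]:
    """Stack equal-length sequences into a (batch_size, seq_len, 4) nested list."""
    if len({len(s) for s in seqs}) > 1:
        raise ValueError("all sequences in a batch must have the same length")
    return [one_hot_encode(s) for s in seqs]
```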
- forward_convs_on_emb(seq_emb, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict=None)
Applies cell-state-specific convolutions to sequence embeddings.
- Args:
seq_emb: Tensor of sequence embeddings.
cell_emb_conv_weights: Convolutional filter weights.
cell_emb_conv_biases: Convolutional filter biases.
bins_to_predict (optional): Indices of bins to predict (if None, predicts all bins).
- Returns:
Tensor: Predicted profiles.
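A cell-state-specific convolution can be sketched for a single cell as a kernel-size-1 convolution: for every bin, each output track is a dot product between the bin's embedding vector and that track's filter, plus a bias. The kernel size of 1, the (embedding_dim, n_bins) layout of seq_emb, and the helper name apply_cell_convs are assumptions made for illustration, not details confirmed by this reference. The optional bins_to_predict argument mirrors the documented behavior of restricting output to selected bins.

```python
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def apply_cell_convs(
    seq_emb: Matrix,
    weights: Matrix,
    biases: List[float],
    bins_to_predict: Optional[Sequence[int]] = None,
) -> Matrix:
    """Apply one cell's kernel-size-1 convolution to a sequence embedding.

    Assumed shapes: seq_emb (embedding_dim, n_bins), weights (n_tracks, embedding_dim),
    biases (n_tracks,). Returns (n_selected_bins, n_tracks).
    """
    n_bins = len(seq_emb[0]) if seq_emb else 0
    # None means "predict every bin", matching the documented default.
    bins = range(n_bins) if bins_to_predict is None else bins_to_predict
    out = []
    for b in bins:
        # Column b of the embedding is the feature vector for that bin.
        column = [row[b] for row in seq_emb]
        # Each track: dot product with its filter plus its bias.
        out.append(
            [sum(w * x for w, x in zip(w_t, column)) + bias for w_t, bias in zip(weights, biases)]
        )
    return out
```

Predicting only a subset of bins skips the per-bin work for all other positions, which is presumably why the option exists.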
- forward_sequence_w_convs(sequence, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict=None)
Processes DNA sequences, applies cell-state-specific convolutions to the resulting embeddings, and caches the results.
- Args:
sequence: Tensor of DNA sequences.
cell_emb_conv_weights: Convolutional filter weights.
cell_emb_conv_biases: Convolutional filter biases.
bins_to_predict (optional): Indices of bins to predict.
- Returns:
Tensor: Predicted profiles.
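The constructor exposes cachesize=2 and disable_cache=False, and this method is documented as caching results, but the caching policy itself is not described here. The sketch below assumes a least-recently-used cache keyed by a hashable identifier for the sequence (for example a string or a region name, since tensors are not directly usable as dictionary keys). The class EmbeddingCache is hypothetical and only illustrates why caching pays off: the expensive backbone call runs once per sequence, while many cells' convolutions can reuse its output.

```python
from collections import OrderedDict
from typing import Any, Callable, Hashable


class EmbeddingCache:
    """Hypothetical LRU cache for sequence embeddings (illustration only)."""

    def __init__(
        self,
        compute: Callable[[Hashable], Any],
        cachesize: int = 2,
        disable_cache: bool = False,
    ) -> None:
        self.compute = compute  # stands in for the expensive backbone call
        self.cachesize = cachesize
        self.disable_cache = disable_cache
        self._store: "OrderedDict[Hashable, Any]" = OrderedDict()
        self.misses = 0

    def get(self, key: Hashable) -> Any:
        if not self.disable_cache and key in self._store:
            # Cache hit: mark the entry as most recently used and reuse it.
            self._store.move_to_end(key)
            return self._store[key]
        # Cache miss (or caching disabled): recompute the embedding.
        self.misses += 1
        value = self.compute(key)
        if not self.disable_cache:
            self._store[key] = value
            if len(self._store) > self.cachesize:
                # Evict the least recently used entry.
                self._store.popitem(last=False)
        return value
```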
- forward(sequence, cell_emb)
Forward pass of the Scooby model.
- Args:
sequence: Tensor of DNA sequences (batch_size, seq_len, 4).
cell_emb: Tensor of cell embeddings (batch_size, num_cells, cell_emb_dim).
- Returns:
Tensor: Predicted profiles for each cell (batch_size, num_cells, seq_len, n_tracks).
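The output layout (batch_size, num_cells, seq_len, n_tracks) follows from sharing one sequence embedding per batch element across all of that element's cells. The sketch below reproduces only this shape flow with nested lists; it calls the position axis n_bins (the docstring labels it seq_len) and assumes kernel-size-1 per-cell filters of shape (n_tracks, embedding_dim). The function batched_cell_predictions is hypothetical and is not the scooby implementation.

```python
from typing import List

Tensor3 = List[List[List[float]]]
Tensor4 = List[List[List[List[float]]]]


def batched_cell_predictions(
    seq_embs: Tensor3,      # (batch_size, embedding_dim, n_bins)
    conv_weights: Tensor4,  # (batch_size, num_cells, n_tracks, embedding_dim)
    conv_biases: Tensor3,   # (batch_size, num_cells, n_tracks)
) -> Tensor4:
    """Return predictions shaped (batch_size, num_cells, n_bins, n_tracks)."""
    out: Tensor4 = []
    for emb, cell_ws, cell_bs in zip(seq_embs, conv_weights, conv_biases):
        n_bins = len(emb[0])
        # Computed once per batch element and shared by every cell below.
        columns = [[row[n] for row in emb] for n in range(n_bins)]
        per_cell = []
        for ws, bs in zip(cell_ws, cell_bs):
            # One (n_bins, n_tracks) profile per cell.
            per_cell.append(
                [
                    [sum(w * x for w, x in zip(w_t, col)) + b_t for w_t, b_t in zip(ws, bs)]
                    for col in columns
                ]
            )
        out.append(per_cell)
    return out
```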