(very low effort) i designed a simple SSM head

Posted by smoothbrain_1947@reddit | LocalLLaMA

like the title says, this is a very low effort post/project, and i am mostly a 28 year old high school graduate useless NEET, so this thing has almost no chance of outperforming attention, mamba or rwkv, nor was that its goal. i just wanted to see if i could design something that sort of approximates a finite-tape, finite-step turing machine.

the basic idea: the heads in each layer have a bunch of slots, and the input (which comes from the previous layer) gets to decide which slots to overwrite and which slots the mlp gets to read. we do our K, Q and V projections, then project the k and q vectors from d_head to n_slots with W_e (this can go to a higher or lower dim). a projection is basically a bunch of dot scores, so W_e simply tells us how similar the k and q vectors are to the slot identity vectors, which are stored within the projection itself. after that, each projection output gets softmaxed with its own learnable temperature. the k softmax decides the overwrite strengths for the slots, and the q softmax weighs the slot contents before they are summed, just like vanilla attention. the slots are just simple selective SSMs: if a(t) is the k softmax score, then:

h(t)=(1-a(t))h(t-1)+a(t)v(t)
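
to make that concrete, here is a naive per-timestep sketch of what one head computes. this is just for intuition, not the code i trained with (the actual module below computes the same thing with a vectorized scan), and the names here are only illustrative:

# naive per-timestep sketch of one head, for intuition only
# (the trained module below vectorizes this recurrence with a cumulative scan)
import torch
import torch.nn.functional as F

def slot_head_step(x_t, h, W_k, W_q, W_v, W_e, tau_w, tau_r):
    # x_t: [d_model], h: [n_slots, d_head], W_* are nn.Linear layers, tau_* are temperatures
    k, q, v = W_k(x_t), W_q(x_t), W_v(x_t)        # [d_head] each
    a = F.softmax(W_e(k) / tau_w, dim=-1)         # write weights over slots, [n_slots]
    r = F.softmax(W_e(q) / tau_r, dim=-1)         # read weights over slots,  [n_slots]
    # selective update per slot: h_s = (1 - a_s) * h_s + a_s * v
    h = (1.0 - a).unsqueeze(-1) * h + a.unsqueeze(-1) * v.unsqueeze(0)
    y = (r.unsqueeze(-1) * h).sum(dim=0)          # weighted sum of slot contents, [d_head]
    return y, h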

anyway. these "heads" replace the attention heads in a GPT. with d_model=384, n_layers=6, d_head=48, ffn_mult=4, n_slots=48, we get about 11M parameters. i used absolute positional encodings; i am not sure if RoPE would have worked, i just went with the "safe" option.
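
i am not posting the full gpt code. as a rough illustration, a layer could run a few of these heads in place of attention and mix their outputs back to d_model, something like the sketch below, using the DenseSlotMemoryHead module given further down. the concat + W_o wiring here is just one plausible way to do it, not necessarily exactly what my script did:

# rough sketch of a layer that uses these heads in place of attention.
# the multi-head concat + output projection W_o is an assumed wiring, for illustration only.
import torch
import torch.nn as nn

class SlotMemoryLayer(nn.Module):
    def __init__(self, d_model=384, d_head=48, n_slots=48):
        super().__init__()
        n_heads = d_model // d_head                      # 8 heads with the config above
        self.heads = nn.ModuleList(
            [DenseSlotMemoryHead(d_model, d_head, n_slots) for _ in range(n_heads)]
        )
        self.W_o = nn.Linear(n_heads * d_head, d_model)  # mix head outputs back to d_model

    def forward(self, x):                                # x: [B, T, d_model]
        outs, aux = [], 0.0
        for head in self.heads:
            y, _, aux_loss = head(x)                     # slots reset every call
            outs.append(y)
            aux = aux + aux_loss
        return self.W_o(torch.cat(outs, dim=-1)), aux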

here is the head module. i didn't write it myself, i have no coding skills; i just explained the maths to chatgpt and told it to keep the recurrences in fp32 and to soft-clamp the softmax temps. it's probably not very optimized, but it works:

import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSlotMemoryHead(nn.Module):
    """
    Dense (non-sparse) slot-memory head (per-sequence SSM style).

    - Input x: [B, T, d_model]
    - Internal projections: d_model -> d_head
    - Slot routing via dense softmax over n_slots with learnable temperature
    - Selective recurrence over slots (vectorized over time, scan done in fp32)
    - Slots are always reset per call (slot_state=None; this is SSM-like)

    Returns:
        y_out     : [B, T, d_head]
        new_state : [B, n_slots, d_head]  (unused if you reset every sequence)
        aux_loss  : scalar (slot usage balance loss)
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_slots: int,
        use_bias: bool = False,
        temp_min: float = 0.1,
        temp_max: float = 10.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.n_slots = n_slots

        self.temp_min = temp_min
        self.temp_max = temp_max

        # Model -> head projections
        self.W_k = nn.Linear(d_model, d_head, bias=use_bias)
        self.W_q = nn.Linear(d_model, d_head, bias=use_bias)
        self.W_v = nn.Linear(d_model, d_head, bias=use_bias)

        # Head -> slot logits (shared for write and read)
        self.W_e = nn.Linear(d_head, n_slots, bias=False)

        # Learnable temperatures (scalar) for write/read softmax
        self.temp_write_logit = nn.Parameter(torch.zeros(()))
        self.temp_read_logit = nn.Parameter(torch.zeros(()))

    def _get_temps(self, dtype, device):
        """Compute write/read temperatures, softly clamped to [temp_min, temp_max]."""
        write_logit = self.temp_write_logit.to(device=device, dtype=dtype)
        read_logit = self.temp_read_logit.to(device=device, dtype=dtype)

        span = self.temp_max - self.temp_min
        temp_write = self.temp_min + span * torch.sigmoid(write_logit)
        temp_read = self.temp_min + span * torch.sigmoid(read_logit)

        return temp_write, temp_read

    def forward(
        self,
        x: torch.Tensor,                           # [B, T, d_model]
        slot_state: torch.Tensor | None = None,    # [B, n_slots, d_head] or None
    ):
        B, T, Dm = x.shape
        assert Dm == self.d_model

        device = x.device
        dtype = x.dtype

        # Slot initial state (per sequence, like an SSM)
        if slot_state is None:
            H0 = torch.zeros(B, self.n_slots, self.d_head, device=device, dtype=dtype)
        else:
            H0 = slot_state.to(device=device, dtype=dtype)

        # 1) Project all timesteps to head space
        k = self.W_k(x)  # [B, T, d_head]
        q = self.W_q(x)
        v = self.W_v(x)  # [B, T, d_head]

        # 2) Slot logits
        B_, T_, Dh = k.shape
        k_e = self.W_e(k.view(B_ * T_, Dh)).view(B, T, self.n_slots)  # [B, T, n_slots]
        q_e = self.W_e(q.view(B_ * T_, Dh)).view(B, T, self.n_slots)

        # 3) Learnable temperatures + dense softmax routing
        temp_write, temp_read = self._get_temps(dtype=dtype, device=device)
        eps_temp = torch.finfo(dtype).eps
        tw = torch.clamp(temp_write, min=eps_temp)
        tr = torch.clamp(temp_read,  min=eps_temp)

        k_e_scaled = k_e / tw
        q_e_scaled = q_e / tr

        write_weights = F.softmax(k_e_scaled, dim=-1)  # [B, T, n_slots]
        read_weights  = F.softmax(q_e_scaled, dim=-1)  # [B, T, n_slots]

        # 4) Slot usage aux loss (encourage uniform write usage)
        slot_usage = write_weights.mean(dim=(0, 1))    # [n_slots]
        aux_loss = ((slot_usage * self.n_slots - 1.0) ** 2).mean()

        # 5) Selective recurrence over slots
        a_dense = torch.clamp(write_weights, 0.0, 1.0 - 1e-5)  # [B, T, n_slots]
        A = 1.0 - a_dense                                      # [B, T, n_slots]

        v_expanded = v.unsqueeze(2)                            # [B, T, 1, d_head]
        B_term = a_dense.unsqueeze(-1) * v_expanded            # [B, T, n_slots, d_head]

        # Slot-major layout
        A_slot = A.permute(0, 2, 1).contiguous()               # [B, n_slots, T]
        B_slot = B_term.permute(0, 2, 1, 3).contiguous()       # [B, n_slots, T, d_head]

        # Do the scan in fp32 for numerical stability
        A_slot32 = A_slot.to(torch.float32)
        B_slot32 = B_slot.to(torch.float32)
        H0_32 = H0.to(torch.float32)

        C = A_slot32.cumprod(dim=2)                            # [B, n_slots, T]
        eps = torch.finfo(torch.float32).eps
        C_safe = C.clamp(min=eps)

        R = B_slot32 / C_safe.unsqueeze(-1)                    # [B, n_slots, T, d_head]
        S = R.cumsum(dim=2)                                    # [B, n_slots, T, d_head]

        H0_exp = H0_32.unsqueeze(2)                            # [B, n_slots, 1, d_head]
        H_seq32 = C.unsqueeze(-1) * (H0_exp + S)               # [B, n_slots, T, d_head]

        H_seq = H_seq32.to(dtype=dtype)                        # [B, n_slots, T, d_head]
        new_state = H_seq[:, :, -1, :]                         # [B, n_slots, d_head]

        # 6) Readout
        H_bt = H_seq.permute(0, 2, 1, 3).contiguous()          # [B, T, n_slots, d_head]
        y_out = torch.sum(read_weights.unsqueeze(-1) * H_bt, dim=2)  # [B, T, d_head]

        return y_out, new_state, aux_loss
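
a quick shape check on random input is enough to see it run (the shapes follow from the docstring above):

# quick smoke test: shapes only, slots reset per call (slot_state=None)
import torch

head = DenseSlotMemoryHead(d_model=384, d_head=48, n_slots=48)
x = torch.randn(2, 64, 384)                # [B, T, d_model]
y, state, aux = head(x)
print(y.shape)       # torch.Size([2, 64, 48])
print(state.shape)   # torch.Size([2, 48, 48])
print(aux.item())    # scalar slot-usage balance loss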

i tested this head, with the hyperparams given above, inside a gpt. all heads were replaced with this one, so no vanilla attention heads at all. the model was able to solve 24-digit addition within 40k steps with a batch size of 192, lr from 3e-4 to 3e-5 with cosine annealing, and adamw as the optimizer. i ran it in bf16 on my 3060. the samples were created as:

24digits+24digits=25digits

to keep the length fixed and make the model's job easier. i did a 16-digit run too, and the same model solved it in under 25k steps.
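
i am not posting the data script, but one simple way to get that fixed-width format is zero-padding, roughly like this (a sketch, not necessarily my exact script):

# rough sketch of the fixed-length addition samples (format: 24digits+24digits=25digits)
import random

def make_sample(n_digits=24):
    a = random.randrange(10 ** n_digits)
    b = random.randrange(10 ** n_digits)
    # zero-pad both operands to n_digits and the sum to n_digits + 1
    return f"{a:0{n_digits}d}+{b:0{n_digits}d}={a + b:0{n_digits + 1}d}"

print(make_sample(24))   # always 24 + 1 + 24 + 1 + 25 = 75 characters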

like i said, i am not expecting this thing to go anywhere, and i am just someone who occasionally tinkers with ml. i don't think there is anything new or exciting about this model, and it's highly unlikely to perform better than anything, but it works, and i came up with it myself, though i was obviously heavily inspired by the selective recurrences used in mamba, rwkv etc. it's possible that this thing just replicates them and i wouldn't even know, because i didn't actually read their papers.