(Very low effort) I designed a simple SSM head
Posted by smoothbrain_1947@reddit | LocalLLaMA | 3 comments
Like the title says, this is a very low-effort post/project. I am mostly a 28-year-old high school graduate useless NEET, so this thing has almost no chance of outperforming attention, Mamba or RWKV, nor was that its goal. I just wanted to see if I could design something that can sort of approximate a finite-tape, finite-step Turing machine.
The basic idea is that each head in each layer has a bunch of slots, and the input (which comes from the previous layer) gets to decide which slots to overwrite and which slots the MLP gets to read. We do our K, Q and V projections. After that, we project the k and q vectors from d_head to n_slots with W_e (n_slots can be larger or smaller than d_head). A projection is basically a bunch of dot products, so W_e simply tells us how similar the k and q vectors are to the slot identity vectors, which are stored within the projection itself (one per row of W_e). After that, each projected output gets softmaxed with its own learnable temperature. The k softmax decides the overwrite strengths for the slots, and the q softmax weighs the slot contents before they are summed, just like vanilla attention. The slots are just simple selective SSMs: if a(t) is the k softmax score for a given slot, then:
h(t)=(1-a(t))h(t-1)+a(t)v(t)
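To make the routing concrete, here is a minimal NumPy sketch of the mechanism above for a single sequence. It uses a plain loop instead of the vectorized fp32 scan. It assumes k, q and v are already projected and treats the temperatures as fixed numbers rather than learnable parameters. `slot_head_reference` is a hypothetical name, not part of the post's code.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def slot_head_reference(k, q, v, W_e, temp_write=1.0, temp_read=1.0):
    """Step-by-step slot-memory head for one sequence (hypothetical helper).

    k, q, v : [T, d_head]        already-projected key/query/value vectors
    W_e     : [n_slots, d_head]  each row plays the role of one slot identity vector
    returns : [T, d_head]        read-weighted sum of the slot contents at each step
    """
    n_slots, d_head = W_e.shape
    h = np.zeros((n_slots, d_head))  # all slots start empty
    out = np.zeros((len(v), d_head))
    for t in range(len(v)):
        a = softmax(W_e @ k[t] / temp_write)  # write (overwrite) strength per slot
        r = softmax(W_e @ q[t] / temp_read)   # read weight per slot
        # h_i(t) = (1 - a_i(t)) * h_i(t-1) + a_i(t) * v(t), for every slot i at once
        h = (1.0 - a)[:, None] * h + a[:, None] * v[t]
        out[t] = r @ h                        # weighted sum over slots
    return out


# A key aligned with slot 0 and a low write temperature give a ≈ [1, 0],
# so slot 0 is overwritten with v(0); a matching query reads it straight back.
W_e = np.eye(2)
k = q = np.array([[10.0, 0.0]])
v = np.array([[3.0, -1.0]])
print(slot_head_reference(k, q, v, W_e, temp_write=0.1, temp_read=0.1))  # → [[ 3. -1.]]
```

A low temperature pushes the write softmax toward a one-hot "overwrite this slot" operation. A high temperature spreads a small update over every slot, which is presumably why the post makes the temperatures learnable.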
Anyway, these "heads" are used to replace the attention heads in a GPT. With d_model=384, n_layers=6, d_head=48, ffn_mult=4 and n_slots=48, we get about 11M parameters. I used absolute positional encodings. I am not sure if RoPE would have worked, so I just went with the "safe" option.
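Here is a rough check of the ~11M figure. The post does not spell out the surrounding GPT, so everything beyond the listed hyperparameters is an assumption. That includes 8 heads per layer (384 / 48), a bias-free 384x384 output projection after concatenating the heads, an FFN with biases, and two LayerNorms per block plus a final one. It also assumes a hypothetical 16-token vocabulary with tied input/output embeddings and 75 learned positions (24 + 1 + 24 + 1 + 25, assuming one token per character).

```python
# Rough parameter count; every value marked "assumed" is not from the post.
d_model, n_layers, d_head, ffn_mult, n_slots = 384, 6, 48, 4, 48
vocab, n_pos = 16, 75  # assumed vocabulary size and context length

n_heads = d_model // d_head                              # 8 heads per layer
per_head = 3 * d_model * d_head + d_head * n_slots + 2   # W_k, W_q, W_v, W_e, 2 temps
attn_out = d_model * d_model                             # assumed output projection, no bias
d_ff = ffn_mult * d_model
ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model   # two Linear layers with biases
norms = 2 * (2 * d_model)                                # two LayerNorms (weight + bias)

per_layer = n_heads * per_head + attn_out + ffn + norms
embeddings = vocab * d_model + n_pos * d_model           # tied output head adds nothing
total = n_layers * per_layer + embeddings + 2 * d_model  # + final LayerNorm

print(f"per head:  {per_head:,}")         # → per head:  57,602
print(f"per layer: {per_layer:,}")        # → per layer: 1,791,376
print(f"total:     {total / 1e6:.2f}M")   # → total:     10.78M
```

Under these assumptions the FFNs account for roughly 7.1M of the total and the slot heads (including W_e and the temperatures) for roughly 2.8M, so the head itself is a fairly small part of the model.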
Here is the head module. I didn't write it; I have no coding skills. I just explained the maths to ChatGPT and told it to keep the recurrences in fp32 and to soft-clamp the softmax temps. It's probably not very optimized, but it works:
```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSlotMemoryHead(nn.Module):
    """
    Dense (non-sparse) slot-memory head (per-sequence SSM style).

    - Input x: [B, T, d_model]
    - Internal projections: d_model -> d_head
    - Slot routing via dense softmax over n_slots with learnable temperature
    - Selective recurrence over slots (vectorized over time, scan done in fp32)
    - Slots start from zero on every call unless slot_state is passed in (SSM-like)

    Returns:
        y_out     : [B, T, d_head]
        new_state : [B, n_slots, d_head] (unused if you reset every sequence)
        aux_loss  : scalar (slot usage balance loss)
    """

    def __init__(
        self,
        d_model: int,
        d_head: int,
        n_slots: int,
        use_bias: bool = False,
        temp_min: float = 0.1,
        temp_max: float = 10.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.n_slots = n_slots
        self.temp_min = temp_min
        self.temp_max = temp_max

        # Model -> head projections
        self.W_k = nn.Linear(d_model, d_head, bias=use_bias)
        self.W_q = nn.Linear(d_model, d_head, bias=use_bias)
        self.W_v = nn.Linear(d_model, d_head, bias=use_bias)

        # Head -> slot logits (shared for write and read);
        # each row of W_e.weight acts as one slot identity vector
        self.W_e = nn.Linear(d_head, n_slots, bias=False)

        # Learnable temperatures (scalar) for write/read softmax
        self.temp_write_logit = nn.Parameter(torch.zeros(()))
        self.temp_read_logit = nn.Parameter(torch.zeros(()))

    def _get_temps(self, dtype, device):
        """Compute write/read temperatures, softly clamped to [temp_min, temp_max]."""
        write_logit = self.temp_write_logit.to(device=device, dtype=dtype)
        read_logit = self.temp_read_logit.to(device=device, dtype=dtype)
        span = self.temp_max - self.temp_min
        temp_write = self.temp_min + span * torch.sigmoid(write_logit)
        temp_read = self.temp_min + span * torch.sigmoid(read_logit)
        return temp_write, temp_read

    def forward(
        self,
        x: torch.Tensor,  # [B, T, d_model]
        slot_state: torch.Tensor | None = None,  # [B, n_slots, d_head] or None
    ):
        B, T, Dm = x.shape
        assert Dm == self.d_model
        device = x.device
        dtype = x.dtype

        # Slot initial state (per sequence, like an SSM)
        if slot_state is None:
            H0 = torch.zeros(B, self.n_slots, self.d_head, device=device, dtype=dtype)
        else:
            H0 = slot_state.to(device=device, dtype=dtype)

        # 1) Project all timesteps to head space
        k = self.W_k(x)  # [B, T, d_head]
        q = self.W_q(x)  # [B, T, d_head]
        v = self.W_v(x)  # [B, T, d_head]

        # 2) Slot logits (dot products with the slot identity vectors)
        B_, T_, Dh = k.shape
        k_e = self.W_e(k.view(B_ * T_, Dh)).view(B, T, self.n_slots)  # [B, T, n_slots]
        q_e = self.W_e(q.view(B_ * T_, Dh)).view(B, T, self.n_slots)  # [B, T, n_slots]

        # 3) Learnable temperatures + dense softmax routing
        temp_write, temp_read = self._get_temps(dtype=dtype, device=device)
        eps_temp = torch.finfo(dtype).eps
        tw = torch.clamp(temp_write, min=eps_temp)
        tr = torch.clamp(temp_read, min=eps_temp)
        k_e_scaled = k_e / tw
        q_e_scaled = q_e / tr
        write_weights = F.softmax(k_e_scaled, dim=-1)  # [B, T, n_slots]
        read_weights = F.softmax(q_e_scaled, dim=-1)   # [B, T, n_slots]

        # 4) Slot usage aux loss (encourage uniform write usage)
        slot_usage = write_weights.mean(dim=(0, 1))  # [n_slots]
        aux_loss = ((slot_usage * self.n_slots - 1.0) ** 2).mean()

        # 5) Selective recurrence over slots:
        #    h(t) = (1 - a(t)) * h(t-1) + a(t) * v(t), independently per slot
        a_dense = torch.clamp(write_weights, 0.0, 1.0 - 1e-5)  # [B, T, n_slots]
        A = 1.0 - a_dense                              # [B, T, n_slots]
        v_expanded = v.unsqueeze(2)                    # [B, T, 1, d_head]
        B_term = a_dense.unsqueeze(-1) * v_expanded    # [B, T, n_slots, d_head]

        # Slot-major layout
        A_slot = A.permute(0, 2, 1).contiguous()             # [B, n_slots, T]
        B_slot = B_term.permute(0, 2, 1, 3).contiguous()     # [B, n_slots, T, d_head]

        # Do the scan in fp32 for numerical stability.
        # Closed form of the linear recurrence:
        #   C(t) = prod_{s<=t} A(s),  H(t) = C(t) * (H0 + sum_{s<=t} B(s) / C(s))
        # Caveat: once C(t) falls below eps (e.g. after a few near-1 writes),
        # the clamp makes B / C_safe inexact and the result drifts from a true scan.
        A_slot32 = A_slot.to(torch.float32)
        B_slot32 = B_slot.to(torch.float32)
        H0_32 = H0.to(torch.float32)
        C = A_slot32.cumprod(dim=2)                     # [B, n_slots, T]
        eps = torch.finfo(torch.float32).eps
        C_safe = C.clamp(min=eps)
        R = B_slot32 / C_safe.unsqueeze(-1)             # [B, n_slots, T, d_head]
        S = R.cumsum(dim=2)                             # [B, n_slots, T, d_head]
        H0_exp = H0_32.unsqueeze(2)                     # [B, n_slots, 1, d_head]
        H_seq32 = C.unsqueeze(-1) * (H0_exp + S)        # [B, n_slots, T, d_head]
        H_seq = H_seq32.to(dtype=dtype)                 # [B, n_slots, T, d_head]
        new_state = H_seq[:, :, -1, :]                  # [B, n_slots, d_head]

        # 6) Readout: read-weighted sum over slots
        H_bt = H_seq.permute(0, 2, 1, 3).contiguous()                   # [B, T, n_slots, d_head]
        y_out = torch.sum(read_weights.unsqueeze(-1) * H_bt, dim=2)     # [B, T, d_head]
        return y_out, new_state, aux_loss
```
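The fp32 scan in the module relies on the closed form H(t) = C(t) * (H0 + sum over s<=t of B(s)/C(s)), where C(t) is the running product of (1 - a). One cheap sanity check is to compare that shortcut against a plain loop. The NumPy sketch below mirrors it for a single slot with scalar values, including the `clamp(min=eps)` step. The helpers `loop_scan` and `cumprod_scan` are hypothetical names, not from the post. In this toy setting the two agree when writes are mild. Three writes of 0.999 push C(t) below eps, though, and the clamped division then no longer reproduces the recurrence.

```python
import numpy as np

EPS = np.finfo(np.float32).eps  # same clamp value as the head's fp32 scan


def loop_scan(a, v, h0=0.0):
    """Reference: h(t) = (1 - a(t)) * h(t-1) + a(t) * v(t), one step at a time."""
    h, out = np.float32(h0), []
    for a_t, v_t in zip(a, v):
        h = (1 - a_t) * h + a_t * v_t
        out.append(h)
    return np.array(out, dtype=np.float32)


def cumprod_scan(a, v, h0=0.0):
    """Mirror of the cumprod/cumsum shortcut used in the head (single slot)."""
    A = (1.0 - a).astype(np.float32)
    B = (a * v).astype(np.float32)
    C = np.cumprod(A)
    R = B / np.maximum(C, EPS)  # same role as C.clamp(min=eps) in the module
    return C * (np.float32(h0) + np.cumsum(R))


# Mild writes: the running product stays well above eps, so both agree.
rng = np.random.default_rng(0)
a_mild = rng.uniform(0.0, 0.3, size=20).astype(np.float32)
v_mild = rng.normal(size=20).astype(np.float32)
print(np.max(np.abs(loop_scan(a_mild, v_mild) - cumprod_scan(a_mild, v_mild))))

# Three near-1 writes push C below eps and the clamp distorts the result.
a_hard = np.array([0.999, 0.999, 0.999], dtype=np.float32)
v_hard = np.ones(3, dtype=np.float32)
print(loop_scan(a_hard, v_hard))     # every step is close to 1.0
print(cumprod_scan(a_hard, v_hard))  # last step collapses far below 1.0
```

So if slots tend to get overwritten almost completely, a chunked or log-space version of the scan might be worth trying. That is only a suggestion; the post does not test it.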
I tested this head inside a GPT with the hyperparameters given above. All heads were replaced with this one, so there are no vanilla attention heads. The model was able to solve 24-digit addition within 40k steps with a batch size of 192, AdamW as the optimizer, and the lr cosine-annealed from 3e-4 to 3e-5. I ran it in bf16 on my 3060. The samples were created as:
24digits+24digits=25digits
to keep the length fixed and make the model's job easier. I did a 16-digit run too, and the same model solved it in under 25k steps.
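Here is a minimal sketch of how fixed-length samples like these could be generated. `make_addition_sample` is a hypothetical helper, not the author's code. Operands are zero-padded to n digits and the sum to n + 1 digits, so every sample has the same length. The post does not say whether the digits were reversed, so this keeps plain left-to-right order.

```python
import random


def make_addition_sample(n_digits: int, rng: random.Random) -> str:
    """Fixed-length addition sample, e.g. n_digits=24 gives 24+24=25 digits."""
    a = rng.randrange(10 ** n_digits)
    b = rng.randrange(10 ** n_digits)
    # Zero-pad both operands to n_digits and the sum to n_digits + 1
    # (the largest possible sum, 2 * (10**n - 1), always fits in n + 1 digits).
    return f"{a:0{n_digits}d}+{b:0{n_digits}d}={a + b:0{n_digits + 1}d}"


rng = random.Random(0)
sample = make_addition_sample(24, rng)
print(sample, len(sample))  # every 24-digit sample is 75 characters long
```

With one token per character, a 24-digit sample is 24 + 1 + 24 + 1 + 25 = 75 tokens and a 16-digit sample is 51.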
Like I said, I am not expecting this thing to go anywhere; I am just someone who occasionally tinkers with ML. I don't think there is anything new or exciting about this model, and it's highly unlikely to perform better than anything, but it works and I came up with it myself, though I was obviously heavily inspired by the selective recurrences used in Mamba, RWKV, etc. It's possible that this thing just replicates them and I wouldn't even know, because I didn't actually read their papers.
Icy_Gas8807@reddit
Hey, I got lost while reading, before the SSM part. Do you have an architecture diagram? Just to understand the flow. And is there no layer norm? No skip connection? How do you plan to handle exploding and vanishing gradients?
smoothbrain_1947@reddit (OP)
This is just a replacement for vanilla attention; you still need the rest of the stuff that's inside a regular multi-head attention GPT. You have the absolute positional embeddings, then a LayerNorm, then multiple of these heads, a residual/skip, another LayerNorm, the FFN and another residual, just like a vanilla transformer. It probably performs much worse than a regular transformer, but it works.
The idea is very simple. We take the current k vector, dot it with all the slot ID vectors, and do a softmax with a learnable temperature. The higher this softmax score is, the more the contents of that slot are overwritten with the current value vector:
h(t) = (1 - a(t)) ⊙ h(t-1) + a(t) ⊙ v(t), where h(t) is a single slot vector and a(t) is the softmax score for that slot.
We then weigh the slot vectors with the softmax of the dot products between the current q vector and the slot ID vectors, and that's the head output.
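Here is a minimal NumPy sketch of that layout. It shows that the slot head simply slots in where attention heads would go. The assumptions are a pre-LN arrangement, a ReLU FFN, LayerNorms without learnable gain/bias, and a final LayerNorm, which GPT-style models usually have but the thread does not mention. To keep the block self-contained, each head is a stand-in callable mapping [T, d_model] to [T, d_head]. In a real model that would be the slot-memory head. `block`, `gpt_forward` and `stand_in_head` are hypothetical names.

```python
from typing import Callable, Sequence

import numpy as np

Head = Callable[[np.ndarray], np.ndarray]  # [T, d_model] -> [T, d_head]


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm without the learnable gain/bias, for brevity."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def block(x, heads: Sequence[Head], W_o, W_1, W_2):
    """LN -> heads (concatenated) -> output projection -> residual,
    then LN -> FFN -> residual (pre-LN layout assumed)."""
    h = layer_norm(x)
    x = x + np.concatenate([head(h) for head in heads], axis=-1) @ W_o
    ff = np.maximum(layer_norm(x) @ W_1, 0.0) @ W_2  # ReLU FFN (an assumption)
    return x + ff


def gpt_forward(tokens, tok_emb, pos_emb, blocks):
    """Token + absolute positional embeddings, a stack of blocks, a final LN."""
    x = tok_emb[tokens] + pos_emb[: len(tokens)]
    for params in blocks:
        x = block(x, *params)
    return layer_norm(x)


rng = np.random.default_rng(0)
vocab, T, d_model, d_head, n_layers = 16, 12, 32, 8, 2
n_heads = d_model // d_head


def w(*shape):
    return rng.normal(scale=0.1, size=shape)


def stand_in_head(W):
    # Placeholder for a slot-memory head: any [T, d_model] -> [T, d_head] map.
    def head(h):
        return h @ W
    return head


blocks = [([stand_in_head(w(d_model, d_head)) for _ in range(n_heads)],
           w(d_model, d_model), w(d_model, 4 * d_model), w(4 * d_model, d_model))
          for _ in range(n_layers)]
tokens = rng.integers(0, vocab, size=T)
out = gpt_forward(tokens, w(vocab, d_model), w(64, d_model), blocks)
print(out.shape)  # → (12, 32)
```

The residual paths and LayerNorms are what the question about exploding or vanishing gradients usually comes down to in a transformer. The head swap leaves them in place.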
Icy_Gas8807@reddit
So you are trying to store patterns and retrieve them, instead of hoping attention + MLP learns it probabilistically? The goal here is to teach the model to save certain parts, like values, and retrieve them to perform that algorithm? Let's say we train it to sort 40-element lists; it should then be able to sort 50 too. But my question is: doesn't it make more sense to reason, output the algorithm, and then use a tool to solve it?