Patch: MPT

Context: MPT (Modified Transformer) uses ALiBi or Rotary embeddings. This patch fixes rotary position cache invalidation and attention mask expansion for variable-length sequences in a custom MPT block. If you meant something else (ECU patch, firmware, audio plugin), let me know.

1. Rotary embedding with cache reset
----------------------------------------------------------------------

```python
import torch
import torch.nn as nn


class PatchedRotaryEmbedding(nn.Module):
    """Rotary embedding with cache reset on seqlen change."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # cos/sin tables are rebuilt only when the sequence length changes
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = None
```

2. Attention mask expansion
----------------------------------------------------------------------

```python
def patch_attention_mask(
    attention_mask: torch.Tensor,
    query_length: int,
    key_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Expand a 2D/3D/4D mask to additive (batch, 1, query_len, key_len) form."""
    # If already 4D, assume correct
    if attention_mask.dim() == 4:
        return attention_mask.to(dtype)

    batch = attention_mask.size(0)

    # Case: (batch, key_len) -> expand to (batch, 1, 1, key_len)
    if attention_mask.dim() == 2:
        mask = attention_mask[:, None, None, :]
    # Case: (batch, query_len, key_len) -> insert a head dimension
    else:
        mask = attention_mask[:, None, :, :]

    # Broadcast to query_len
    mask = mask.expand(batch, 1, query_length, key_length)

    # Convert to additive mask (0 = keep, -inf = mask)
    return (
        mask.to(dtype)
        .masked_fill(mask == 0, 0.0)
        .masked_fill(mask == 1, float("-inf"))
    )
```

3. Monkey-patch into existing MPT model (example)
----------------------------------------------------------------------

```python
def apply_mpt_patches(model: nn.Module):
    """Replace rotary and mask functions in an existing MPT model."""
    # Patch rotary class if found
    for name, module in model.named_modules():
        if "rotary" in name.lower() and hasattr(module, "cos_cached"):
            module.__class__ = PatchedRotaryEmbedding
            print(f"[PATCH] Replaced rotary in {name}")
```

4. Quick test
----------------------------------------------------------------------

```python
# Test attention mask expansion
mask_2d = torch.tensor([[0, 0, 1, 1]])  # batch=1, key_len=4
expanded = patch_attention_mask(
    mask_2d, query_length=3, key_length=4, dtype=torch.float32
)
print(f"Expanded mask shape: {expanded.shape}")  # torch.Size([1, 1, 3, 4])
print(expanded)
```

| Issue | Before patch | After patch |
|-------|--------------|-------------|
| Rotary cache | Recomputes every call, wastes memory | Only recomputes when seqlen changes |
| Mask expansion | Only supports 2D masks | Supports 2D/3D/4D, correct broadcast |
| Cross-attention | Mask shape mismatch | Proper (batch, 1, q_len, k_len) |

If you meant a firmware patch for an MPT controller (like in automotive or industrial PLCs), I can write a .bin patching script in Python or C. Just clarify the target.
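The `PatchedRotaryEmbedding` shown above stores `_cached_cos`/`_cached_sin` but the forward pass that consults them is not shown. Below is a minimal sketch of the cache-on-demand logic the patch describes; the `forward` signature, the `_build_cache` helper, the `rebuilds` counter, and the assumption that input is shaped `(batch, seq, dim)` are all mine, not from the original patch.

```python
import torch
import torch.nn as nn


class RotaryCacheSketch(nn.Module):
    """Sketch: rebuild rotary cos/sin tables only when seqlen grows."""

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = None
        self.rebuilds = 0  # instrumentation for this example only

    def _build_cache(self, seq_len: int, device, dtype):
        # Standard rotary tables: theta_i = base^(-2i/dim)
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
        )
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, inv_freq)         # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        self._cached_cos = emb.cos().to(dtype)
        self._cached_sin = emb.sin().to(dtype)
        self._cached_seq_len = seq_len
        self.rebuilds += 1

    def forward(self, x: torch.Tensor):
        # Assumes x is (batch, seq, dim)
        seq_len = x.shape[1]
        # Key fix: rebuild only when the cached length no longer covers the input
        if self._cached_seq_len is None or seq_len > self._cached_seq_len:
            self._build_cache(seq_len, x.device, x.dtype)
        return self._cached_cos[:seq_len], self._cached_sin[:seq_len]
```

A shorter sequence reuses the existing tables (sliced to length); only a longer one triggers a rebuild, which is the "only recomputes when seqlen changes" behavior claimed in the table.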
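`apply_mpt_patches` relies on reassigning an instance's `__class__`, which swaps its method resolution while leaving its `__dict__` (cached state) untouched. A torch-free sketch of that technique with stand-in classes (not real MPT modules):

```python
class OldRotary:
    """Stand-in for an existing rotary module holding cached state."""
    def __init__(self):
        self.cos_cached = "old-table"


class NewRotary(OldRotary):
    """Stand-in for the patched class: adds behavior, keeps state."""
    def describe(self):
        return "patched"


rot = OldRotary()
rot.__class__ = NewRotary  # swap the class; instance attributes are untouched
```

This is also why the real patch checks `hasattr(module, "cos_cached")` first: after the swap the instance keeps its old attributes, so the replacement class must expect compatible state.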
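On the additive convention (0 = keep, -inf = mask): adding `-inf` to an attention score drives its softmax weight to exactly zero, so masked positions contribute nothing. A small pure-Python illustration (the `softmax` helper is defined here just for the example):

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # math.exp(float("-inf")) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0, 0.5]                        # raw attention scores
additive = [0.0, 0.0, float("-inf"), float("-inf")]  # 0 = keep, -inf = mask
probs = softmax([s + a for s, a in zip(scores, additive)])
# masked positions get exactly zero weight; the rest renormalize to sum to 1
```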