1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
import math

import torch
class FlashAttention(torch.nn.Module):
    """Block-wise (Flash-style) causal self-attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` with a causal mask, processing
    keys/values in tiles of ``block_size`` while maintaining an online
    (running-max) softmax, so the result matches a full dense attention pass
    without materialising the whole ``seq_len x seq_len`` score matrix.
    """

    def __init__(self, hidden_size, num_heads, block_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.block_size = block_size

    def forward(self, q, k, v):
        """Causal attention forward pass.

        Args:
            q, k, v: tensors of shape (batch, num_heads, seq_len, head_dim).

        Returns:
            Tensor of the same shape as ``q``.
        """
        batch_size, num_heads, seq_len, head_dim = q.shape
        scale = 1.0 / math.sqrt(head_dim)
        output = torch.zeros_like(q)

        for i in range(0, seq_len, self.block_size):
            end_i = min(i + self.block_size, seq_len)
            q_block = q[:, :, i:end_i, :]

            # Online-softmax state for this query tile:
            #   m   - running row-wise max of scores seen so far,
            #   l   - running softmax denominator,
            #   acc - running (unnormalised) weighted sum of values.
            m = torch.full(q_block.shape[:-1], float('-inf'),
                           dtype=q.dtype, device=q.device)
            l = torch.zeros_like(m)
            acc = torch.zeros_like(q_block)

            # BUG FIX vs. original: the original softmaxed each key block
            # independently and summed the block outputs (not a valid softmax
            # over the row), and only masked the diagonal block, so queries
            # attended to strictly-future key blocks. We now skip fully-future
            # blocks entirely (causal) and renormalise across blocks with the
            # standard running-max recurrence.
            for j in range(0, end_i, self.block_size):
                end_j = min(j + self.block_size, seq_len)
                k_block = k[:, :, j:end_j, :]
                v_block = v[:, :, j:end_j, :]

                scores = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
                if i == j:
                    # Diagonal tile: mask keys strictly after each query position.
                    causal = torch.triu(torch.ones_like(scores), diagonal=1)
                    scores = scores.masked_fill(causal.bool(), float('-inf'))

                # Online-softmax update: rescale old state by exp(m - m_new).
                m_new = torch.maximum(m, scores.max(dim=-1).values)
                alpha = torch.exp(m - m_new)
                p = torch.exp(scores - m_new.unsqueeze(-1))
                l = l * alpha + p.sum(dim=-1)
                acc = acc * alpha.unsqueeze(-1) + torch.matmul(p, v_block)
                m = m_new

            # Final normalisation by the accumulated denominator.
            output[:, :, i:end_i, :] = acc / l.unsqueeze(-1)

        return output
class SparseAttention(torch.nn.Module):
    """Attention restricted by a fixed sparsity pattern.

    ``sparsity_pattern`` selects which key positions each query may attend to:
      - ``"local"``:   a sliding window of ``+/- min(128, seq_len // 4)``
        positions around each query.
      - ``"strided"``: every ``max(1, seq_len // 64)``-th key position.
    """

    def __init__(self, hidden_size, num_heads, sparsity_pattern="local"):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.sparsity_pattern = sparsity_pattern

    def create_sparsity_mask(self, seq_len, device):
        """Build the (seq_len, seq_len) boolean attendance mask.

        Returns:
            Bool tensor where ``True`` marks an allowed (query, key) pair.

        Raises:
            ValueError: if ``self.sparsity_pattern`` is not recognised.
                (The original fell through with no ``else`` and crashed with
                an opaque ``UnboundLocalError`` instead.)
        """
        if self.sparsity_pattern == "local":
            window_size = min(128, seq_len // 4)
            # Band matrix |i - j| <= window_size, built in one vectorised
            # broadcast instead of a Python loop over rows.
            idx = torch.arange(seq_len, device=device)
            mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window_size
        elif self.sparsity_pattern == "strided":
            stride = max(1, seq_len // 64)
            mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
            # Every query attends to every stride-th key column; the original
            # filled this with an O(n^2) pair of nested Python loops.
            mask[:, ::stride] = True
        else:
            raise ValueError(
                f"unknown sparsity pattern: {self.sparsity_pattern!r}")
        return mask

    def forward(self, q, k, v):
        """Masked scaled-dot-product attention.

        Args:
            q, k, v: tensors of shape (batch, num_heads, seq_len, head_dim).

        Returns:
            Tensor of the same shape as ``q``.
        """
        batch_size, num_heads, seq_len, head_dim = q.shape
        sparsity_mask = self.create_sparsity_mask(seq_len, q.device)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        # Disallowed positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(~sparsity_mask, float('-inf'))
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        return output
|