All layers are exported from myvllm.layers and are designed for tensor-parallel inference with torch.distributed. Each layer that performs tensor parallelism reads the process group rank and world size from torch.distributed at construction time.

Activations

SiLU-gated MLP activations

Normalization

Fused RMSNorm with residual add

Linear

Tensor-parallel column and row splits

Embeddings

Vocab-parallel lookup and LM head

RoPE

Rotary positional embeddings

Sampler

Temperature sampling

SiluAndMul

myvllm.layers.activation.SiluAndMul
A fused activation layer used in MLP blocks. It expects the gate and up projections to be concatenated along the last dimension (as produced by MergedColumnParallelLinear). Internally it splits the input in half along that dimension, applies SiLU to the first half, and multiplies element-wise by the second half.
output = SiLU(x[..., :d]) * x[..., d:]
The forward method is compiled with torch.compile.

forward

forward(x: torch.Tensor) -> torch.Tensor
Argument | Shape | Description
x | (..., 2 * d) | Concatenated gate + up projection output
returns | (..., d) | SiLU-gated activations
Example
import torch
from myvllm.layers import SiluAndMul

act = SiluAndMul().cuda()
x = torch.randn(4, 16).cuda()   # last dim must be even
out = act(x)                    # shape (4, 8)
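The split-and-gate arithmetic can be checked without a GPU. The following is a minimal plain-Python sketch for a single row of the input; `silu` and `silu_and_mul_row` are illustrative names and are not part of myvllm.

```python
import math


def silu(v: float) -> float:
    """SiLU (swish): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def silu_and_mul_row(row: list[float]) -> list[float]:
    """Apply SiLU to the first half of `row` and multiply by the second half."""
    d, remainder = divmod(len(row), 2)
    if remainder:
        raise ValueError("last dimension must be even (gate half + up half)")
    gate, up = row[:d], row[d:]
    return [silu(g) * u for g, u in zip(gate, up)]


# Hypothetical input: gate = [0.0, 1.0], up = [2.0, 3.0]
example_out = silu_and_mul_row([0.0, 1.0, 2.0, 3.0])
```

The output has half the width of the input, which matches the (..., 2 * d) to (..., d) shapes listed for forward.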

LayerNorm

myvllm.layers.layernorm.LayerNorm
Root Mean Square Layer Normalization (RMSNorm). Supports an optional fused residual add: the residual is added to x before normalizing, and both the normalized output and the updated residual are returned, which eliminates a separate addition in the decoder layer.
RMSNorm(x) = (x / sqrt(mean(x²) + ε)) ⊙ γ
The rms_forward method is compiled with torch.compile.

Constructor

gamma
torch.Tensor
required
Initial scale (γ) parameter of shape (hidden_size,). Copied and wrapped in nn.Parameter so that it participates in gradient computation and checkpoint loading.
eps
float
default:"1e-5"
Small constant added inside the square root for numerical stability.

forward

forward(x: torch.Tensor, residual: torch.Tensor | None = None)
When residual is None, returns the normalized tensor directly. When residual is provided, fuses the residual add:
  1. Computes x_new = x + residual.
  2. Normalizes x_new with RMSNorm.
  3. Returns (normalized, x_new) — the second element is the updated residual for the next sub-layer.
output
torch.Tensor
Normalized tensor, same shape as x.
residual
torch.Tensor
Updated residual x + residual_in. Only returned when residual is passed as input.
Example
import torch
from myvllm.layers import LayerNorm

gamma = torch.ones(512).cuda()
norm = LayerNorm(gamma, eps=1e-6).cuda()

x        = torch.randn(4, 16, 512).cuda()
residual = torch.randn(4, 16, 512).cuda()

# Plain RMSNorm
out = norm(x)

# Fused residual add + RMSNorm
out, new_residual = norm(x, residual)
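
As a reference for the formula and the fused path, here is a minimal plain-Python sketch that operates on one hidden vector. The function names `rms_norm` and `fused_add_rms_norm` are assumptions made for illustration; the real layer works on tensors and is compiled with torch.compile.

```python
import math


def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm over one hidden vector: x / sqrt(mean(x^2) + eps) * gamma."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def fused_add_rms_norm(x, residual, gamma, eps=1e-5):
    """Add the residual first, then normalize; return (normalized, x + residual)."""
    x_new = [a + b for a, b in zip(x, residual)]
    return rms_norm(x_new, gamma, eps), x_new


# Hypothetical 2-element hidden vector with gamma = 1 and eps = 0.
ref_gamma = [1.0, 1.0]
ref_out, ref_residual = fused_add_rms_norm([1.0, 1.0], [2.0, 3.0], ref_gamma, eps=0.0)
```

With gamma set to ones and eps set to zero, the mean of the squared outputs is 1, which is a quick sanity check for any RMSNorm implementation.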

Linear layers

All linear layers inherit from LinearBase and implement a weight_loader method that knows how to extract the correct shard from a full pre-trained checkpoint tensor and copy it into the local (already-sharded) parameter buffer.

weight_loader pattern

When loading a checkpoint, iterate over the model’s named parameters and call the custom loader if present:
for name, param in model.named_parameters():
    if name in checkpoint:
        loaded_weight = checkpoint[name]   # full tensor from disk
        if hasattr(param, 'weight_loader'):
            param.weight_loader(param, loaded_weight)
        else:
            param.data.copy_(loaded_weight)
For merged projections (MergedColumnParallelLinear, QKVColumnParallelLinear), the loader accepts an extra argument that identifies which sub-matrix the checkpoint tensor belongs to. See packed_module_mapping in the model classes for how names are resolved.

ColumnParallelLinear

myvllm.layers.linear.ColumnParallelLinear
Splits the output dimension across tensor-parallel ranks. Each GPU holds output_size / tp_size output rows. No collective communication is needed in the forward pass because the outputs are naturally sharded.
input_size
int
required
Full input feature dimension (replicated on all GPUs).
output_size
int
required
Full output feature dimension before sharding. Must be divisible by tp_size.
bias
bool
default:"True"
Whether to add a bias term.
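
To make the sharding concrete, the sketch below simulates two ranks in plain Python and checks that concatenating the per-rank outputs reproduces the unsharded result. The helpers `matvec` and `column_parallel_shard` are hypothetical, and the weight is assumed to be stored as (output_size, input_size) rows, as in F.linear.

```python
def matvec(weight, x):
    """y = W @ x for W stored as a list of rows (output_size x input_size)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def column_parallel_shard(weight, tp_rank, tp_size):
    """Rows of W owned by `tp_rank` when the output dimension is split evenly."""
    output_size = len(weight)
    assert output_size % tp_size == 0, "output_size must be divisible by tp_size"
    shard = output_size // tp_size
    return weight[tp_rank * shard:(tp_rank + 1) * shard]


col_W = [[1, 2], [3, 4], [5, 6], [7, 8]]  # output_size=4, input_size=2
col_x = [1, 1]
col_tp_size = 2

# Each simulated rank computes its own slice of the output; nothing is exchanged.
col_per_rank = [
    matvec(column_parallel_shard(col_W, rank, col_tp_size), col_x)
    for rank in range(col_tp_size)
]
col_full = matvec(col_W, col_x)
```

Because every rank sees the full input, the shards only need to be concatenated (or consumed directly by a following RowParallelLinear) to recover the full output.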

RowParallelLinear

myvllm.layers.linear.RowParallelLinear
Splits the input dimension across tensor-parallel ranks. Each GPU holds input_size / tp_size input columns. An all_reduce is performed in the forward pass to sum the partial results and produce a replicated output.
input_size
int
required
Full input feature dimension before sharding. Must be divisible by tp_size.
output_size
int
required
Output feature dimension (replicated on all GPUs after all_reduce).
bias
bool
default:"True"
Whether to add a bias term.
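
The role of the all_reduce can be illustrated with a small plain-Python simulation. This is a sketch under the assumption that each rank holds a contiguous block of input columns and the matching slice of the input; `row_parallel_partial` is a hypothetical helper, and the element-wise sum stands in for all_reduce.

```python
def row_parallel_partial(weight, x, tp_rank, tp_size):
    """Partial output of one rank that owns a slice of the input dimension."""
    input_size = len(x)
    assert input_size % tp_size == 0, "input_size must be divisible by tp_size"
    shard = input_size // tp_size
    lo, hi = tp_rank * shard, (tp_rank + 1) * shard
    # The rank holds columns lo..hi of W and the matching slice of x.
    return [sum(w * v for w, v in zip(row[lo:hi], x[lo:hi])) for row in weight]


row_W = [[1, 2, 3, 4], [5, 6, 7, 8]]  # output_size=2, input_size=4
row_x = [1, 1, 1, 1]
row_tp_size = 2

row_partials = [
    row_parallel_partial(row_W, row_x, rank, row_tp_size)
    for rank in range(row_tp_size)
]
# Stand-in for all_reduce(SUM): element-wise sum of the partials across ranks.
row_reduced = [sum(vals) for vals in zip(*row_partials)]
```

Each partial result already has the full output width, so the ranks must sum rather than concatenate, which is why this layer needs a collective while ColumnParallelLinear does not.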

MergedColumnParallelLinear

myvllm.layers.linear.MergedColumnParallelLinear
An extension of ColumnParallelLinear that stores multiple matrices (e.g., the gate and up projections of an MLP) as a single fused weight matrix. This lets both projections be computed in one F.linear call.
The weight_loader accepts a loaded_weight_id: int argument that specifies which sub-matrix (by index into output_sizes) the incoming checkpoint tensor corresponds to.
input_size
int
required
Input feature dimension.
output_sizes
list[int]
required
List of output sizes for each merged sub-matrix. For a standard gate+up MLP: [intermediate_size, intermediate_size].
bias
bool
default:"True"
Whether to add a bias term.
Loading from a checkpoint
# gate projection is index 0, up projection is index 1
merged.weight.weight_loader(merged.weight, checkpoint['gate_proj.weight'], loaded_weight_id=0)
merged.weight.weight_loader(merged.weight, checkpoint['up_proj.weight'],   loaded_weight_id=1)
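
The index passed as loaded_weight_id has to be turned into row offsets. The sketch below shows one plausible way to do that arithmetic; it assumes each rank stores its sub-matrix shards back to back in output_sizes order, which this page does not state explicitly, and `merged_shard_plan` is a hypothetical helper rather than a myvllm function.

```python
def merged_shard_plan(output_sizes, loaded_weight_id, tp_rank, tp_size):
    """Return (dst_offset, shard_size, src_start) in units of output rows.

    dst_offset: first row of this sub-matrix inside the local fused weight
    shard_size: number of rows this rank keeps from the sub-matrix
    src_start:  first row to read from the full checkpoint tensor
    """
    assert all(size % tp_size == 0 for size in output_sizes)
    dst_offset = sum(output_sizes[:loaded_weight_id]) // tp_size
    shard_size = output_sizes[loaded_weight_id] // tp_size
    src_start = tp_rank * shard_size
    return dst_offset, shard_size, src_start


# Hypothetical sizes: gate and up both have 8 rows, split over 2 ranks.
gate_plan = merged_shard_plan([8, 8], loaded_weight_id=0, tp_rank=1, tp_size=2)
up_plan = merged_shard_plan([8, 8], loaded_weight_id=1, tp_rank=1, tp_size=2)
```

Under these assumptions, rank 1 reads rows 4 to 7 of each checkpoint tensor and writes the gate shard at local rows 0 to 3 and the up shard at local rows 4 to 7.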

QKVColumnParallelLinear

myvllm.layers.linear.QKVColumnParallelLinear
A specialized column-parallel linear that packs the Q, K, and V projections into a single weight matrix. Accommodates grouped-query attention by allowing num_kv_heads < num_heads, so K and V occupy fewer output rows than Q.
Per-GPU output size: head_size * (num_heads/tp_size + 2 * num_kv_heads/tp_size).
input_size
int
required
Hidden size of the model.
head_size
int
required
Dimension of each attention head.
num_heads
int
required
Total number of query heads across all tensor-parallel ranks.
num_kv_heads
int
Total number of key/value heads. Defaults to num_heads.
bias
bool
default:"False"
Whether to add a bias term.
The weight_loader accepts a load_weight_id: str argument — one of 'q', 'k', or 'v' — to route each checkpoint tensor to the correct offset within the fused weight.
qkv.weight.weight_loader(qkv.weight, checkpoint['q_proj.weight'], load_weight_id='q')
qkv.weight.weight_loader(qkv.weight, checkpoint['k_proj.weight'], load_weight_id='k')
qkv.weight.weight_loader(qkv.weight, checkpoint['v_proj.weight'], load_weight_id='v')
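
The per-GPU output size formula can be worked through with a small helper. This is a sketch only: `qkv_layout` is a hypothetical function, and the Q, then K, then V ordering of the offsets inside the fused weight is an assumption based on the layer name.

```python
def qkv_layout(head_size, num_heads, num_kv_heads, tp_size):
    """Per-rank (offset, size) of Q, K, V in output rows, plus the total rows."""
    assert num_heads % tp_size == 0 and num_kv_heads % tp_size == 0
    q_size = head_size * (num_heads // tp_size)
    kv_size = head_size * (num_kv_heads // tp_size)
    layout = {
        "q": (0, q_size),
        "k": (q_size, kv_size),
        "v": (q_size + kv_size, kv_size),
    }
    return layout, q_size + 2 * kv_size


# Hypothetical GQA config: 32 query heads, 8 KV heads, head_size 128, 4 ranks.
qkv_offsets, qkv_rows_per_gpu = qkv_layout(128, 32, 8, 4)
```

For this configuration each rank holds 8 query heads and 2 KV heads, so the fused weight has 128 * (8 + 2 * 2) = 1536 output rows per GPU.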

VocabParallelEmbedding

myvllm.layers.embedding_head.VocabParallelEmbedding
An embedding table that partitions the vocabulary across tensor-parallel ranks. Each GPU owns ceil(num_embeddings / tp_size) token embeddings. Tokens outside a rank’s range contribute a zero vector; an all_reduce sums contributions so that every rank receives the correct full embedding.
num_embeddings
int
required
True vocabulary size before padding.
embedding_dim
int
required
Embedding dimension.
The vocabulary is padded up to the next multiple of tp_size so that it shards evenly. The padding rows are zeroed and do not affect output correctness.
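
The masking and summation described above can be simulated in plain Python. The sketch assumes contiguous token-id ranges per rank; `vocab_shard_range` and `local_lookup` are hypothetical helpers, and the element-wise sum stands in for all_reduce.

```python
import math


def vocab_shard_range(num_embeddings, tp_rank, tp_size):
    """Half-open token-id range [start, end) owned by `tp_rank`."""
    per_rank = math.ceil(num_embeddings / tp_size)
    return tp_rank * per_rank, (tp_rank + 1) * per_rank


def local_lookup(token_id, local_table, start, end):
    """Embedding for tokens inside [start, end), zero vector otherwise."""
    if start <= token_id < end:
        return local_table[token_id - start]
    return [0.0] * len(local_table[0])


vocab_size, vocab_tp_size, emb_dim = 10, 4, 2
padded_vocab = math.ceil(vocab_size / vocab_tp_size) * vocab_tp_size
# Toy table: row i is [i, i]; padding rows beyond vocab_size are zero.
full_table = [
    [float(i)] * emb_dim if i < vocab_size else [0.0] * emb_dim
    for i in range(padded_vocab)
]

token_id = 7
emb_partials = []
for rank in range(vocab_tp_size):
    start, end = vocab_shard_range(vocab_size, rank, vocab_tp_size)
    emb_partials.append(local_lookup(token_id, full_table[start:end], start, end))

# Stand-in for all_reduce(SUM): exactly one rank contributes a non-zero row.
embedding = [sum(vals) for vals in zip(*emb_partials)]
```

With a vocabulary of 10 and 4 ranks, each rank owns 3 rows and the table is padded to 12 rows; token 7 is served by rank 2 only.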

ParallelLMHead

myvllm.layers.embedding_head.ParallelLMHead
Inherits from VocabParallelEmbedding and reuses the same weight matrix for the final logit projection (weight tying). In prefill mode it selects only the last token of each sequence before computing logits, which avoids unnecessary computation.
In a tensor-parallel setup each rank computes logits for its vocabulary shard. Rank 0 then gathers all shards via dist.gather, concatenates them, and trims the result to the true vocab size.
num_embeddings
int
required
Vocabulary size.
embedding_dim
int
required
Model hidden size.
Weight tying
# Share embedding weights with the LM head (common in Llama / Qwen models)
lm_head.weight = model.embed_tokens.weight
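
Two steps of the LM head lend themselves to a small illustration: picking the last token of each prefill sequence, and trimming the gathered logits back to the true vocabulary. The sketch below assumes a flattened prefill batch described by per-sequence lengths; `last_token_indices` and `gather_and_trim` are hypothetical helpers, and the list concatenation stands in for dist.gather on rank 0.

```python
def last_token_indices(seq_lens):
    """Index of the final token of each sequence in a flattened prefill batch."""
    indices, end = [], 0
    for length in seq_lens:
        end += length
        indices.append(end - 1)
    return indices


def gather_and_trim(shard_logits, vocab_size):
    """Concatenate per-rank logit shards in rank order and drop padding columns."""
    full = [value for shard in shard_logits for value in shard]
    return full[:vocab_size]


# Hypothetical prefill batch of three sequences with lengths 3, 1, and 4.
picked = last_token_indices([3, 1, 4])

# Hypothetical vocab of 10 padded to 12 and split over 4 ranks (3 logits each).
logit_shards = [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8], [0.9, -1.0, -1.0]]
trimmed_logits = gather_and_trim(logit_shards, vocab_size=10)
```

Only the selected rows need to go through the logit projection, and the two trailing padding columns never reach the sampler.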

RotaryEmbedding

myvllm.layers.rotary_embedding.RotaryEmbedding
Computes Rotary Position Embeddings (RoPE). At construction time it precomputes a (max_position, rotary_embedding) cache of interleaved [cos, sin] values; at inference time it looks up the relevant rows using the position indices. Supports the Llama 3 long-context frequency scaling strategy (NTK-by-parts), enabled by setting is_llama3=True.

Constructor

base
int
required
RoPE base frequency (e.g., 10000 for Qwen3, 500000 for Llama 3).
rotary_embedding
int
required
Number of head dimensions to apply rotary embedding to. Typically equal to head_dim.
max_position
int
default:"2048"
Maximum sequence length the cache is pre-computed for.
is_llama3
bool
default:"False"
Enable Llama 3 NTK-by-parts frequency scaling for long-context support.

forward

forward(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]
Looks up the cached cos/sin values for the given positions and applies apply_rotary_pos_emb to both query and key. Supports both varlen (total_tokens, num_heads, head_dim) and batched (B, seq_len, num_heads, head_dim) input shapes.
Argument | Type | Description
positions | torch.Tensor | 1-D integer tensor of token positions
query | torch.Tensor | Query tensor
key | torch.Tensor | Key tensor
returns | (q_rotated, k_rotated) | Rotated query and key tensors, same shapes as inputs
Example
import torch
from myvllm.layers import RotaryEmbedding

rope = RotaryEmbedding(base=10000, rotary_embedding=64, max_position=4096).cuda()

positions = torch.arange(16).cuda()          # (seq_len,)
q = torch.randn(16, 8, 64).cuda()            # (total_tokens, num_heads, head_dim)
k = torch.randn(16, 2, 64).cuda()            # (total_tokens, num_kv_heads, head_dim)

q_rot, k_rot = rope(positions, q, k)
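
For readers who want to see the cache and the rotation spelled out, here is a minimal plain-Python sketch of standard RoPE without Llama 3 scaling. The cache layout (cos values followed by sin values in each row) and the rotate-half pairing of dimensions i and i + half are assumptions; the actual layout used by RotaryEmbedding and apply_rotary_pos_emb may differ. `build_cos_sin_cache` and `apply_rope` are hypothetical names.

```python
import math


def build_cos_sin_cache(base, rotary_dim, max_position):
    """Row p holds cos(p * f_i) for every frequency f_i, followed by sin(p * f_i)."""
    half = rotary_dim // 2
    inv_freq = [base ** (-2.0 * i / rotary_dim) for i in range(half)]
    cache = []
    for pos in range(max_position):
        angles = [pos * f for f in inv_freq]
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


def apply_rope(vec, cache_row):
    """Rotate one head vector using the rotate-half pairing (i, i + half)."""
    half = len(vec) // 2
    cos, sin = cache_row[:half], cache_row[half:]
    x1, x2 = vec[:half], vec[half:]
    first = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    second = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return first + second


rope_cache = build_cos_sin_cache(base=10000, rotary_dim=4, max_position=8)
q_vec = [1.0, 2.0, 3.0, 4.0]
q_at_0 = apply_rope(q_vec, rope_cache[0])  # position 0: identity rotation
q_at_5 = apply_rope(q_vec, rope_cache[5])  # position 5: rotated, same norm
```

Each cache row has rotary_dim entries, matching the (max_position, rotary_embedding) shape described above, and the rotation leaves the vector norm unchanged at every position.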

SamplerLayer

myvllm.layers.sampler.SamplerLayer
Applies temperature scaling to logits and samples the next token using Gumbel-max sampling, which is equivalent in distribution to categorical sampling but is implemented with a single argmax. The forward method is compiled with torch.compile.

forward

forward(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor
Argument | Shape | Description
logits | (batch_size, vocab_size) | Raw logits from the LM head
temperature | (batch_size,) | Per-sequence temperature values. Use 1.0 for no scaling
returns | (batch_size,) | Sampled token IDs
Sampling procedure
  1. Divide logits by temperature: logits / temperature.
  2. Compute softmax probabilities.
  3. Sample via Gumbel-max: argmax(probs / Exponential(1)) — equivalent to multinomial sampling.
Example
import torch
from myvllm.layers import SamplerLayer

sampler = SamplerLayer().cuda()
logits      = torch.randn(4, 32000).cuda()        # (batch_size, vocab_size)
temperature = torch.full((4,), 0.8).cuda()        # per-sequence temperature
tokens = sampler(logits, temperature)             # (4,)
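
The claim that argmax(probs / Exponential(1)) matches multinomial sampling can be checked empirically. The sketch below uses only the Python standard library and a single sequence; `softmax` and `sample_exponential_race` are hypothetical helpers, and the clamp on the noise value is an assumption added to avoid a division by zero.

```python
import math
import random


def softmax(logits, temperature):
    """Temperature-scaled softmax over one row of logits."""
    scaled = [v / temperature for v in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_exponential_race(probs, rng):
    """Return argmax_i probs[i] / E_i with E_i drawn from Exponential(1)."""
    # Clamp the noise away from zero to guard against a rare 0.0 draw.
    scores = [p / max(rng.expovariate(1.0), 1e-300) for p in probs]
    return max(range(len(probs)), key=scores.__getitem__)


rng = random.Random(0)
target_probs = softmax([0.0, 1.0, 2.0], temperature=1.0)
num_draws = 20000
counts = [0] * len(target_probs)
for _ in range(num_draws):
    counts[sample_exponential_race(target_probs, rng)] += 1
freqs = [c / num_draws for c in counts]
```

The empirical frequencies in `freqs` should track `target_probs` closely. The equivalence holds because -log(E) with E drawn from Exponential(1) follows a standard Gumbel distribution, so taking the argmax of probs / E is the same as taking the argmax of log(probs) plus Gumbel noise.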