Optimizing Large-Scale Pretraining at Character.ai
Before Character.ai shifted its focus toward building on open-source model foundations, the company’s early pretraining team explored a range of techniques to make large-scale transformer training faster and more efficient. That work - led in part by our cofounder Noam Shazeer - is now being shared publicly for the first time.
This post highlights several of those techniques, including Squinch, a 6-bit gradient compression algorithm, along with complementary methods for quantization, regularization, and distillation that shaped our early approach to scalable model training.
While Character.ai no longer does large-scale pretraining, the ideas developed during that period live on in our codebase and continue to inform how we train open-source models today. You can explore our current open-source work in pipelining-sft and Ovi, or join the team to help develop the next generation of conversational AI systems.
The following sections summarize five of those techniques.
Gradient Compression: Squinch
Squinch is a 6-bit gradient compression algorithm invented by Noam Shazeer during Character.ai’s early pretraining efforts. It was designed to maintain the same model accuracy as training with bfloat16 gradients while dramatically reducing communication bandwidth between nodes.
At the time, Character.ai’s largest pretraining cluster operated with only one-quarter of the bandwidth of state-of-the-art systems. Squinch enabled efficient distributed training under these constraints by block-wise quantizing gradients to 6 bits per element. Each block encodes eight gradient values into a compact 48-bit representation that captures both sign and magnitude. A compressed squinch block looks like this:
---------------------8 bits---------------------
+-----------------------------------------------+
|                     q_max                     |
+-----------------------------------------------+
|                     signs                     |
+-----------------------------------------------+
|  q_elems[0] (4 bits)  |  q_elems[1] (4 bits)  |
+-----------------------------------------------+
|  q_elems[2] (4 bits)  |  q_elems[3] (4 bits)  |
+-----------------------------------------------+
|  q_elems[4] (4 bits)  |  q_elems[5] (4 bits)  |
+-----------------------------------------------+
|  q_elems[6] (4 bits)  |  q_elems[7] (4 bits)  |
+-----------------------------------------------+

where

max_in_block = max(|elems|)
q_max = clip(int(6 * log(max_in_block) + 129), 0, 255)
maxabs = exp((q_max - 128) / 6)
q_elems[i] = min(int(sqrt(|elems[i]| / maxabs) * 15 + 0.5), 15)

Although network bandwidth is less of a bottleneck for some clusters, the approach remains relevant for cross-domain training, sparse MoE models, and any environment where bandwidth is limited.
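For concreteness, here is a minimal pure-Python sketch of compressing and decompressing a single block of eight gradient values under this scheme. The function names and the unpacked (q_max, signs, q_elems) representation are illustrative; the production kernel packs these fields into the 48-bit layout shown above.

import math

def squinch_quantize_block(elems):
    """Compress 8 gradient values into (q_max, signs, q_elems): 8 + 8 + 8 * 4 = 48 bits."""
    assert len(elems) == 8
    max_in_block = max(abs(e) for e in elems)
    if max_in_block == 0.0:
        return 0, 0, [0] * 8  # all-zero block
    # Shared 8-bit block scale on a log grid (6 steps per factor of e).
    q_max = max(0, min(255, int(6 * math.log(max_in_block) + 129)))
    maxabs = math.exp((q_max - 128) / 6)
    # One sign bit per element, packed into a byte.
    signs = sum(1 << i for i, e in enumerate(elems) if e < 0)
    # 4-bit magnitudes on a square-root scale relative to maxabs.
    q_elems = [min(int(math.sqrt(abs(e) / maxabs) * 15 + 0.5), 15) for e in elems]
    return q_max, signs, q_elems

def squinch_dequantize_block(q_max, signs, q_elems):
    """Approximate inverse of squinch_quantize_block."""
    maxabs = math.exp((q_max - 128) / 6)
    decoded = []
    for i, q in enumerate(q_elems):
        mag = (q / 15.0) ** 2 * maxabs
        decoded.append(-mag if (signs >> i) & 1 else mag)
    return decoded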
Unlike general quantization schemes such as bitsandbytes or DeepSpeed qgZ, Squinch's dynamic range is tuned specifically to transformer gradients, whose values tend to fall within a well-regularized distribution.
On well-regularized transformers, Squinch delivered lower communication cost with negligible loss in training fidelity. Later work on DeepSeek's logfmt attempted to leverage it for other parts of the model recipe.
Precision Regularization: Attention Z-Reg
Attention Z-Reg is a regularization technique applied to attention logits to keep their numerical range well-behaved during training. It shifts the logits so that the summed activation (“Z” value) remains close to 0, allowing the optimization to use the high-precision range of bfloat16 representation.
This matters because the numeric resolution of bfloat16 decreases at large magnitudes: for example, the floating-point steps between 40 and 41 are far coarser than those between 0 and 1.
In prior art, ST-MoE introduced a z-loss term applied to router logits. At Character.ai, we applied the same idea to attention logits and linear model logits. Note that it is not implemented as an explicit loss term; its gradient is simply added during the backward pass. The “virtual” loss term is:

loss += attention_z_reg * square(logsumexp(logits)) / (num_heads * num_layers)

In practice, the gradient can be computed directly as part of the attention backward pass.
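For illustration, here is a minimal PyTorch sketch of the virtual loss as an explicit term for a single attention layer. The function name and the reduction over batch, head, and query dimensions are assumptions; the production implementation instead injects the equivalent gradient inside the attention backward.

import torch

def attention_z_reg_loss(
    attn_logits_BHQK: torch.Tensor,  # pre-softmax attention logits [batch, heads, query, key]
    attention_z_reg: float,
    num_heads: int,
    num_layers: int,
) -> torch.Tensor:
    """Virtual z-regularization term for one attention layer.

    Penalizes the squared log-partition ("Z") of the attention softmax so that
    logits stay in the high-precision region of bfloat16.
    """
    z_BHQ = torch.logsumexp(attn_logits_BHQK.float(), dim=-1)  # log of the softmax denominator
    # The exact reduction is an assumption: sum over heads, mean over batch and query.
    return attention_z_reg * z_BHQ.square().sum(dim=1).mean() / (num_heads * num_layers)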
Quantization Stability: Dynamic Clamping
Dynamic Clamping is a quantization-aware training technique used to prevent small activation values from collapsing to zero during training.
Consider an FFN with a ReLU² (squared ReLU) activation, where we apply clamping at the quantization step on both the input projection and the output projection, as sketched below.
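The following is a minimal PyTorch sketch of the static scheme. quantize_int8 is a hypothetical fake-quantization helper standing in for the actual QAT quantizer, and the tensor names follow the conventions used in this section (x_AD, w_in_DF, w_out_FD, middle_in_AF, middle_out_AF).

import torch

def quantize_int8(x: torch.Tensor, limit: float) -> torch.Tensor:
    """Illustrative fake quantization: clamp to [-limit, limit], then snap to 255 levels."""
    x = x.clamp(-limit, limit)
    scale = limit / 127.0
    return (x / scale).round() * scale

def ffn_static_clamp(x_AD: torch.Tensor,
                     w_in_DF: torch.Tensor,
                     w_out_FD: torch.Tensor) -> torch.Tensor:
    """FFN with squared-ReLU activation and constant clamping limits at quantization."""
    middle_in_AF = x_AD @ w_in_DF
    middle_in_AF = quantize_int8(middle_in_AF, limit=4.0)     # clamp to [-4, 4]
    middle_out_AF = torch.relu(middle_in_AF) ** 2             # values land in [0, 16]
    middle_out_AF = quantize_int8(middle_out_AF, limit=16.0)  # quantize with limit 16
    return middle_out_AF @ w_out_FD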
In our QAT recipe, the clamping limit determines the numerical precision of the subsequent quantization. When the rms of w_in_DF gets very small, the middle_in_AF and middle_out_AF values cluster in a narrow range around 0. When we then quantize middle_out_AF with a limit of 16, most of the values are quantized to 0, which harms training stability and accuracy.
In dynamic clamping, instead of clamping middle_in_AF to constant limits of [-4, 4], we dynamically calculate the range based on the rms of w_in_DF and plug it into the example above, as sketched next.
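The exact scaling rule is an assumption here; one plausible sketch scales the static limits by the rms of w_in_DF, reusing quantize_int8 from the sketch above.

def ffn_dynamic_clamp(x_AD: torch.Tensor,
                      w_in_DF: torch.Tensor,
                      w_out_FD: torch.Tensor) -> torch.Tensor:
    """FFN where the clamping limits track the rms of w_in_DF.

    The factor 4.0 mirrors the static limit above; the actual proportionality
    constant used in training is an illustrative assumption.
    """
    limit_in = float(4.0 * w_in_DF.pow(2).mean().sqrt())  # shrinks as w_in_DF.rms shrinks
    limit_out = limit_in ** 2                             # squared ReLU maps [-L, L] into [0, L^2]
    middle_in_AF = x_AD @ w_in_DF
    middle_in_AF = quantize_int8(middle_in_AF, limit=limit_in)
    middle_out_AF = torch.relu(middle_in_AF) ** 2
    middle_out_AF = quantize_int8(middle_out_AF, limit=limit_out)
    return middle_out_AF @ w_out_FD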
To summarize, dynamic clamping adapts the clamping range of the up-projected states in the FFN to the scale of their inputs. This greatly reduces quantization error and improves training stability.
Efficient Attention API: Visibility Mask
The visibility mask is a compact way to represent the inter-chunk relationships within an item in the batch. It is composed of two tensors, visibility_start and visibility_limit, that describe, for each token, the valid attention range during both training and inference.
The two tensors have shape (batch, context length) and together encode which tokens can attend to which others. For a given token, positions at index < visibility_start cannot attend to it, and positions at index >= visibility_limit cannot attend to it; together the two limits mark chunk boundaries.
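Under this convention, a dense boolean mask can be materialized from the two tensors as a sanity check. The helper below is an illustrative sketch; in practice a fused attention kernel would consume the ranges directly rather than building a dense mask.

import torch

def visibility_to_mask(visibility_start: torch.Tensor,
                       visibility_limit: torch.Tensor) -> torch.Tensor:
    """Expand (batch, context_len) visibility ranges into a dense boolean mask.

    mask[b, q, k] is True when query position q may attend to key position k,
    i.e. visibility_start[b, k] <= q < visibility_limit[b, k].
    """
    batch, context_len = visibility_start.shape
    q_pos = torch.arange(context_len).view(1, context_len, 1)  # query index
    start = visibility_start.view(batch, 1, context_len)
    limit = visibility_limit.view(batch, 1, context_len)
    return (start <= q_pos) & (q_pos < limit)

# Example 1 below: a single causal document of five tokens.
start = torch.tensor([[0, 1, 2, 3, 4]])
limit = torch.tensor([[5, 5, 5, 5, 5]])
causal = torch.arange(5).view(5, 1) >= torch.arange(5).view(1, 5)
assert visibility_to_mask(start, limit)[0].equal(causal)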
Visibility mask has several advantages:
- Natively represents tree-style document relationships, which are common in chat data
- Improves training-system efficiency by packing multiple unrelated chunks together
- Naturally supports bidirectional attention when needed
- Helps inference schemes that use tree-structured sampling algorithms
Example 1: A single document with causal masking
Tokens: [A B C D E]
visibility_start: [0 1 2 3 4]
visibility_limit: [5 5 5 5 5]

Example 2: Two independent documents with causal masking

Tokens: [A B C A' B' C']
visibility_start: [0 1 2 3 4 5]
visibility_limit: [3 3 3 6 6 6]

Example 3: A tree-structured document in which A is the parent of both B and C

Tokens: [A AA AAA B BB BBB C CC CCC]
visibility_start: [0 1 2 3 4 5 6 7 8]
visibility_limit: [9 9 9 6 6 6 9 9 9]

Example 4: A beam search implementation at inference time, with empty slots in paged attention

Tokens: [A B C _ _ D' D'' D'''] # _ means empty slot
visibility_start: [0 1 2 8 8 5 6 7]
visibility_limit: [8 8 8 8 8 6 7 8]

Example 5: Bidirectional prefix attention followed by causal attention

Tokens: [A B C D' E' F'] # 3 bidirectional tokens followed by causal tokens
visibility_start: [0 0 0 3 4 5]
visibility_limit: [6 6 6 6 6 6]

Distillation Optimization: Gumbel Softmax
In model distillation, the soft targets for the student model are equal to the probabilities output by the teacher model. This is feasible if we are running the teacher model online while training the student model.
Alternatively, we run the teacher model offline and save the output probabilities. This is potentially simpler and saves computation if we end up doing multiple student training runs. However, for large vocabulary sizes, the output of the teacher model is large, and storing it is expensive. We can reduce the size of the output by subsampling.
To avoid biasing the student, we subsample in a manner that preserves the expected values of the soft targets. Given a teacher distribution P over the vocabulary V, we randomly sample a subset S from V using a sampling algorithm Q. Let Q(v) denote the probability that v is in S. The soft targets T for the student model are defined by T(v) = P(v) / Q(v) if v is in S and T(v) = 0 otherwise.
We can use gumbel top-k sampling as the sampling algorithm Q. This is approximately equivalent to sampling from P without replacement. We take the log-probabilities, add gumbel noise, and pick the top k values. To estimate the probability Q(v) for some token v in S, we assume that the gumbel noise on all other tokens remains the same and compute the probability that if we resampled the gumbel noise on v, the new value would exceed the (k+1)st highest noised-up log probability.
This algorithm was proposed in "Categorical Reparameterization with Gumbel-Softmax" (arXiv:1611.01144) by Eric Jang, Shixiang Gu, and Ben Poole.
Here’s a snippet of code:
import torch


def sample_gumbel_topk(
    log_prob_XV: torch.Tensor,
    k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Add gumbel noise, then sample top k tokens.

    Intended use is for distillation. Run the teacher model, call this function,
    and save the three output tensors with the training data. At training time,
    use the function soft_targets_from_gumbel_topk to reconstruct dense soft targets.

    Args:
        log_prob_XV: [..., vocab_size] the log probability from the teacher model
        k: number of tokens to sample

    Returns:
        token_Xk: [..., k] selected tokens to use as soft targets for the student model
        log_prob_Xk: [..., k] log probabilities of the selected tokens according to the teacher model
        threshold_Xk: [..., k] noised-up log probabilities for the next highest token
    """
    K = k + 1
    V = log_prob_XV.size(-1)
    assert K <= V
    # Standard Gumbel noise: -log(Exponential(1)).
    noise_XV = -torch.empty_like(log_prob_XV).exponential_().log_()
    noisy_log_prob_XV = noise_XV  # Overwrite this buffer.
    noisy_log_prob_XV.add_(log_prob_XV)
    noisy_log_prob_XK, token_XK = torch.topk(noisy_log_prob_XV, K, dim=-1)
    log_prob_XK = torch.gather(log_prob_XV, -1, token_XK)
    token_Xk = token_XK[..., :k].contiguous()
    log_prob_Xk = log_prob_XK[..., :k].contiguous()
    threshold_Xk = noisy_log_prob_XK[..., 1:].contiguous()
    return token_Xk, log_prob_Xk, threshold_Xk


def soft_targets_from_gumbel_topk(
    token_XK: torch.Tensor,
    log_prob_XK: torch.Tensor,
    threshold_XK: torch.Tensor,
    k: int,
    V: int,
    normalize: bool = True,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute soft targets from gumbel topk samples.

    See sample_gumbel_topk for the definition of the inputs.

    Args:
        token_XK: [..., K]
        log_prob_XK: [..., K]
        threshold_XK: [..., K]
        k: number of tokens to use, k <= K
        V: vocab size
        normalize: whether to normalize the weights to sum to 1

    Returns:
        soft_targets_XV: [..., V]
    """
    K = log_prob_XK.size(-1)
    assert k <= K
    assert k > 0
    threshold_X1 = threshold_XK[..., k - 1].unsqueeze(-1)
    token_Xk = token_XK[..., :k]
    log_prob_Xk = log_prob_XK[..., :k]
    min_noise_Xk = threshold_X1 - log_prob_Xk
    # Increase precision to avoid NaN.
    max_exponential_Xk = (-min_noise_Xk).double().exp()
    # Inclusion probability: P(resampled Gumbel noise on this token exceeds the threshold).
    log_inclusion_prob_Xk = torch.log1p(-(-max_exponential_Xk).exp())
    weight_Xk = torch.exp(log_prob_Xk - log_inclusion_prob_Xk)
    if normalize:
        # Normalize to make the weights add up to 1. Unknown whether this is necessary for good training.
        weight_Xk = weight_Xk / weight_Xk.sum(dim=-1, keepdim=True)
    weight_Xk = weight_Xk.to(dtype)
    soft_targets_XV = torch.zeros(token_XK.size()[:-1] + (V,), device=token_XK.device, dtype=dtype)
    soft_targets_XV.scatter_add_(-1, token_Xk.long(), weight_Xk)
    return soft_targets_XV

This method substantially cuts storage and bandwidth costs for offline distillation runs, while maintaining fidelity to the teacher model’s probability distribution.
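For completeness, a rough sketch of how the two functions might be wired together; the tensor shapes, k = 64, and the cross-entropy against the reconstructed soft targets are illustrative choices rather than the production recipe.

import torch
import torch.nn.functional as F

# Offline: run the teacher and store only k entries per position instead of the full vocab.
V, k = 32000, 64
teacher_log_probs_XV = F.log_softmax(torch.randn(2, 16, V), dim=-1)  # stand-in for teacher output
token_Xk, log_prob_Xk, threshold_Xk = sample_gumbel_topk(teacher_log_probs_XV, k)

# At student-training time: reconstruct dense soft targets and distill against them.
soft_targets_XV = soft_targets_from_gumbel_topk(token_Xk, log_prob_Xk, threshold_Xk, k=k, V=V)
student_logits_XV = torch.randn(2, 16, V, requires_grad=True)  # stand-in for student output
distill_loss = -(soft_targets_XV.float() * F.log_softmax(student_logits_XV, dim=-1)).sum(-1).mean()
distill_loss.backward()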
Other Notes
Several additional techniques used during Character.ai’s large-scale pretraining runs have since become widely adopted or validated in public research. We’ve shared several of these in other posts, but they include:
- INT8 native training using quantized forward passes with per-tensor scaling for both activations and weights. Learnable residual-stream scalars were introduced to improve numerical stability across layers (a rough sketch appears after this list).
- KV sharing with an interleaved structure, which performs better than making the first half of the layers producers and the latter half consumers (internally known as the context encoder, and also known as You Only Cache Once).
- mu-P style hyperparameter tuning and careful scaling law analysis.
- Batch size warmup combined with a shorter learning rate warmup helps with stability when scaling up to larger models. Making sure each rank sees a similar distribution of datasets, via global shuffling during data processing, helps smooth the learning curve.
- Synthetic data augmentation, including rephrasing training sets into different forms of questions, significantly boosts the models’ ability to tackle harder questions. Regarding data mixtures, we take an experimental approach with careful ablations.
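To make the first bullet concrete, here is a minimal sketch of a per-tensor-scaled INT8 linear layer with a learnable residual-stream scalar. The class, helper, and initialization choices are illustrative assumptions, not Character.ai's production kernels, and a real training setup would add a straight-through estimator so gradients flow through the quantization.

import torch
import torch.nn as nn

def quantize_per_tensor_int8(x: torch.Tensor):
    """Symmetric per-tensor INT8 quantization: returns int8 values and the float scale."""
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    return (x / scale).round().clamp(-127, 127).to(torch.int8), scale

class Int8Linear(nn.Module):
    """Linear layer run with INT8 inputs/weights and per-tensor scales,
    plus a learnable scalar on its contribution to the residual stream."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
        self.residual_scale = nn.Parameter(torch.ones(()))  # learnable residual-stream scalar

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [tokens, dim] activations on the residual stream.
        x_q, x_scale = quantize_per_tensor_int8(x)
        w_q, w_scale = quantize_per_tensor_int8(self.weight)
        # Integer matmul accumulated in int32, then dequantized with the product of the scales.
        y = (x_q.to(torch.int32) @ w_q.to(torch.int32).t()).float() * (x_scale * w_scale)
        return x + self.residual_scale * y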
Conclusion
These techniques evolved through practical challenges in scaling conversational model pretraining. Each optimization - whether in gradient compression, quantization, or distillation - reflects Character.ai’s engineering philosophy: small, precise improvements compound into major efficiency gains at scale.
Today, Character.ai is pivoting these optimization capabilities toward a growing post-training RL effort applied to open-source models. Even though we are no longer doing pretraining, the need for efficient, high-scale model systems is greater than ever. If you are passionate about solving these engineering challenges in a post-training context, come help us build the future. Explore open roles at jobs.ashbyhq.com/character.