Weekly Meeting

JAN 15, 2026

Vedant Puri

https://github.com/vpuri3

Mechanical Engineering, Carnegie Mellon University

Advisors: Prof. Burak Kara, Prof. Jessica Zhang

Weekly meeting - 01/15/26

PROPOSED WORK

  • Aim 1(a): 123
  • Aim 1(b): 123
  • Aim 2: 123


NEXT PAPER - ICML (International Conference on Machine Learning)

  • ABC

PROGRESS

  • 123
  • 123
  • 123

NEXT STEPS

  • 123
  • 123
  • 123

Proposed aims

Aim 2: Causal self attention with FLARE

AIM 1(a): rank-adaptive

AIM 1(b): conditioning mechanism

  • Will focus on this last. 
  • Algorithm is prepared and ready for evaluation.
  • No standardized benchmark suite exists; evaluate across different modalities
    • Diffusion: image, video generation (Datta)
    • Time-series PDE problems
  • Design inference algorithm (DONE) and write CUDA kernel (EASY)
  • Design training algorithm and write GPU kernels (Vedant)
  • Evaluate on language modeling applications

Proposed timeline

Notes

  • Our algorithmic work is clearly scoped and straightforward, which is good.
  • Furthermore, the advantages of our methods are well laid out.
  • However, learning several new domains (image/video, generation, language modeling, PDE timeseries) is hard.
  • I'm spending a considerable amount of time managing Datta.
    • His task last semester was well bounded (hyperparameter tuning)
    • This semester, I need him to write several evaluation suites for testing my algorithms on different domains.
    • Ensuring correctness and that best practices are followed requires constant input and deep interactions.

Weekly meeting - 01/22/26

Progress

  • Setting up experiments for aims 1, 2, 3
    • Aim 1a: adaptive latent count -- test on image classification
    • Aim 1b: cross-attention -- test on image/video diffusion
    • Aim 2: next token prediction -- test on language modeling
  • Progress
    • Pipeline nearly ready for Aim 1a.
      • Training was slow because of the volume of data being loaded (1.2M images): ~2 days to train small models.
      • We applied many tricks to bring that down to 5-10 hours.
    • TODO: set up diffusion pipeline on top of vision pipeline
    • Pipeline ready for Aim 2.
      • Set up within a standard testing suite for efficient attention models
  • ICLR notification for FLARE paper should come out today
    • If rejected or notification delayed --> submit to ICML

Setup for language modeling

Benchmarking (model size: 340M params, context length: 2k tokens)

  • Train model on a fixed pretraining dataset (10B tokens) -- 5-8 hours on 4 GPUs.
  • Evaluate on several reasoning datasets.

FLARE Decoder

Advantages

  • Faster training, inference (esp. for long context)
  • Drastically lower inference memory (esp. for long context)

Weekly meeting - 01/29/26

Progress

  • Setting up experiments for aims 1, 2, 3
    • Aim 1a: adaptive latent count -- test on image classification
    • Aim 1b: cross-attention -- test on image/video diffusion
    • Aim 2: next token prediction -- test on language modeling
      • Advantages: faster training, inference (esp. for long context). Drastically lower memory requirement for long context inference.
  • Progress
    • Submitted to ICML
    • Vision transformer pipeline complete (took ~2 weeks)
    • Aim 1: can test on Darcy, or image classification
    • Aim 2: diffusion image generation pipeline complete. Problems are too large for our current compute allocation; requesting a larger one.
    • Aim 3: language model testing in progress.
      • CUDA kernels for FWD pass, BWD pass done
      • Testing small models (50M parameters)
      • Noticing gradient instability

FLARE Decoder: FWD pass N=2048

FLARE Decoder: BWD pass N=2048

FLARE Decoder: FWD pass N=65k

FLARE Decoder: BWD pass N=65k

Debugging training instability

Instability removed with k-normalization, but are we losing performance?
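Assuming "k-normalization" here means L2-normalizing the key vectors before the score computation (my reading, not confirmed by the slides), a minimal sketch shows why it tames the exponential: normalized keys bound the dot-product logits by the query norm, so exp() can no longer overflow.

```python
import numpy as np

def k_normalize(K, eps=1e-6):
    """L2-normalize each key vector (rows of K, shape (N, d)).

    After normalization, |k . q| <= ||q|| by Cauchy-Schwarz, so the
    logits fed to exp() have bounded dynamic range regardless of how
    large the raw key activations grow during training."""
    return K / (np.linalg.norm(K, axis=-1, keepdims=True) + eps)
```

Whether the bounded logits cost modeling capacity is exactly the open question on this slide.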

FLARE causal update

y_t = \sum_{\tau \leq t} \left(\sum_{m=1}^M \frac{P_{mt}}{\sum_{u\leq t}\exp(S_{mu})} \exp(S_{m\tau}) \right) v_{\tau}
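For reference, a direct NumPy transcription of this update (a sketch with assumed shapes, P and S as (M, T) arrays and values V as (T, D); not the CUDA kernel):

```python
import numpy as np

def flare_causal_update(P, S, V):
    """Naive O(T^2) evaluation of the FLARE causal update.

    P: (M, T) latent routing weights P_{mt}
    S: (M, T) scores S_{mt}
    V: (T, D) value vectors v_tau
    Returns Y with Y[t] = y_t."""
    M, T = S.shape
    E = np.exp(S)                 # exp(S_{m tau}), shape (M, T)
    Z = np.cumsum(E, axis=1)      # Z[m, t] = sum_{u <= t} exp(S_{mu})
    Y = np.zeros((T, V.shape[1]))
    for t in range(T):
        # combined weight on each tau <= t, summed over the M latents
        w = (P[:, t:t+1] / Z[:, t:t+1] * E[:, : t + 1]).sum(axis=0)
        Y[t] = w @ V[: t + 1]
    return Y
```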

Weekly meeting - 02/26/26


Manuscript summary

Memory-Efficient Causal Attention via Latent Routing (FLARE Decoder)

Target Conference: NeurIPS 2026 (May 15)

1. Contrib: Latent routing formulation of causal attention

  • We introduce a new formulation of autoregressive attention as routing through a fixed-size latent space, providing a principled low-rank factorization that preserves global context while enabling prefix-sufficient state.

2. Contrib: Constant-memory autoregressive decoding algorithm

  • We derive an exact decoding algorithm whose memory and compute per step are independent of sequence length while maintaining full-context modeling capability.
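Using the FLARE causal update formula, this claim can be illustrated with a minimal NumPy sketch (shapes and variable names are my assumptions): a state (H, z) of size O(M·D) summarizes the prefix exactly, so per-step memory and compute never grow with t.

```python
import numpy as np

def decode_constant_memory(P, S, V):
    """Exact streaming decode of the FLARE causal update.

    Carries H[m] = sum_{tau<=t} exp(S[m,tau]) * V[tau] and
            z[m] = sum_{tau<=t} exp(S[m,tau]),
    which is all the prefix information y_t depends on."""
    M, T = S.shape
    D = V.shape[1]
    H = np.zeros((M, D))
    z = np.zeros(M)
    Y = np.empty((T, D))
    for t in range(T):
        e = np.exp(S[:, t])       # (M,) per-latent score of the new token
        H += e[:, None] * V[t]    # rank-1 state update, O(M*D)
        z += e
        Y[t] = (P[:, t] / z) @ H  # y_t = sum_m P_{mt} H_m / z_m
    return Y
```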

3. (TODO) Contrib: Linear-memory training via chunkwise recomputation algorithm

  • We develop an efficient training algorithm that enables exact gradient computation with bounded memory footprint.
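A minimal sketch of the chunkwise idea (my own illustration, not the paper's algorithm): save only the streaming state at chunk boundaries during the forward pass, so a backward pass can re-materialize any chunk's activations from its boundary state instead of storing all of them, giving memory linear in T.

```python
import numpy as np

def chunked_forward(P, S, V, C=2):
    """Chunked forward pass over the FLARE streaming recurrence.

    Outputs match the full causal update exactly, but only one (H, z)
    boundary state per chunk is retained, so each chunk's intermediates
    can be recomputed independently during backprop."""
    M, T = S.shape
    D = V.shape[1]
    H = np.zeros((M, D)); z = np.zeros(M)
    saved = []                             # boundary states for recomputation
    Y = np.empty((T, D))
    for s in range(0, T, C):
        saved.append((H.copy(), z.copy()))
        for t in range(s, min(s + C, T)):  # intra-chunk work, O(C) activations
            e = np.exp(S[:, t])
            H = H + e[:, None] * V[t]
            z = z + e
            Y[t] = (P[:, t] / z) @ H
    return Y, saved
```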

4. (TODO) Contrib: Optimized GPU kernels for scalable training and inference

  • We provide high-performance kernels that make the proposed method practical at modern training scales.

5. (TODO) Contrib: Adaptive latent queries for content aware compression (new architecture improvement over FLARE)

  • We propose dynamic latent query generation conditioned on the prefix, allowing adaptive routing and improving expressivity over static memory token approaches.
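One hypothetical realization (the names and the linear conditioning are my assumptions, not the paper's design): shift a static latent query bank by a projection of a prefix summary, so the routing adapts to content while reducing to the static bank on an empty prefix.

```python
import numpy as np

def adaptive_latent_queries(Q_static, X_prefix, W):
    """Sketch of prefix-conditioned latent queries.

    Q_static: (M, d) static query bank (as in FLARE)
    X_prefix: (T, d) prefix token embeddings
    W:        (d, d) learned conditioning projection (hypothetical)"""
    ctx = X_prefix.mean(axis=0)    # (d,) cheap prefix summary
    return Q_static + ctx @ W      # same shift broadcast to all M queries
```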

Language: Inference latency (prefill, decode) speedup (contrib 1,2)

Language: Accuracy

Language: challenge

  • We are in the right ballpark to be competitive with other models.
  • However, our accuracy is not high enough at the moment.
  • Reasons
    • We have only been testing with a small number of latent tokens \(M\).
    • This is because the backward kernel is slow for large \(M\).
    • Also, we have not been using any gating-like tricks to modulate the state update.
  • Tasks
    • Optimize kernel backward pass (see manuscript)
    • Introduce gating-like tricks
    • Try adaptive latents idea.

(from gated linear attn paper)
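Adapting the gated-linear-attention idea to the FLARE streaming state might look like the following sketch (the gate parameterization is an assumption): a learned per-latent gate decays stale history before each new token is accumulated, and a saturated gate recovers the ungated update.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_state_update(H, z, g_t, e_t, v_t):
    """GLA-style decay gate on the FLARE streaming state (a sketch).

    H:   (M, D) accumulated sum of exp(score) * value
    z:   (M,)   accumulated sum of exp(score)
    g_t: (M,)   pre-sigmoid gate logits for this step
    e_t: (M,)   exp(S[:, t]) for the new token
    v_t: (D,)   new value vector"""
    g = sigmoid(g_t)                    # gate in (0, 1)^M
    H = g[:, None] * H + e_t[:, None] * v_t
    z = g * z + e_t                     # decay history, then accumulate
    return H, z
```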

Weekly meeting - 03/05/26

Efficient Causal Attention via Latent Routing (FLARE Decoder)

  • Latent routing formulation of causal attention
  • Constant-memory autoregressive decoding algorithm
  • Linear-memory training via chunkwise recomputation algorithm
  • (TODO) Optimized GPU kernels for scalable training and inference
  • (TODO) Adaptive latent queries (new architecture improvement over FLARE)

Progress

  • Optimized backward pass for decoder model
  • Testing ideas for: adaptive queries
    • separate encoder/decoder weights
    • gating
  • Datta - image classification results

Loss curves

Weekly meeting - 03/12/26

Efficient Causal Attention via Latent Routing (FLARE Decoder)

  • Latent routing formulation of causal attention
  • Constant-memory autoregressive decoding algorithm
  • Linear-memory training via chunkwise recomputation algorithm
  • (TODO) Optimized GPU kernels for scalable training and inference
  • (TODO) Adaptive latent queries (new architecture improvement over FLARE)

Progress

  • Optimized backward pass for decoder model
  • Recent modification (separate encoder/decoder weights) worked very well
  • Datta - image classification results

Next steps

  • Test more modifications (addition of causal convolutions, gating)
  • Test block-wise causal modeling
    • with application to video diffusion, PDE timeseries

Training times for FLARE Decoder