Neural Representations for Computational Physics:

Scalable Transformers as PDE surrogates

 

 

Vedant Puri

https://vpuri3.github.io/
MAY 04, 2026

Computer simulations are critical for industrial applications

1

Modern engineering relies on computer simulations

Predictive maintenance

Design space exploration

[2]

[1]

[3]

[1] CFD Direct / OpenFOAM — "OpenFOAM HPC on AWS with EFA", cfd.direct
[2] EurekAlert — "New concrete system may reduce wind-turbine costs"
[3] Flow-3D — "FLOW-3D AM" product page, flow3d.com

Process optimization

Automotive engineering

Civil engineering

Advanced manufacturing

[1]

2

Physics-based simulations bottleneck many engineering workflows

[1] COMSOL — “Mesh Refinement” 

[2] Langtangen, H. P. — INF5620: Finite Element Methods (Lecture Notes)

[3] GridPro Blog — “The Art and Science of Meshing Airfoil”  

[4] ResearchGate — “Transition to turbulence of Taylor-Green Vortex at different time (DNS)” (figure)  
[5] ORNL / U.S. Department of Energy — “DOE and Cray deliver record-setting Frontier supercomputer at ORNL”  

\partial_t \boldsymbol{u} + (\boldsymbol{u} \cdot \boldsymbol{\nabla})\boldsymbol{u} = -\boldsymbol{\nabla} p + \frac{1}{\text{Re}}\Delta \boldsymbol{u} + \boldsymbol{f}\\ \boldsymbol{\nabla}\cdot\boldsymbol{u} = 0

Governing Equations

\boldsymbol{\nabla}\cdot \boldsymbol{\sigma} + \boldsymbol{F} = \rho\boldsymbol{\ddot{{u}}}

[1]

[2]

Discretization machinery

Repeated large system solves

[5]

Multiscale physics \(\implies\) small \(\Delta t\)

[4]

Complex geometry \(\implies\) fine meshes

[3]


The cost of this procedure scales poorly for several reasons.

Neural signal representations learn to emulate physics from data

3

{\mathbf{u}}(\mathbf{x}) = \sum_{i=1}^{N} \mathbf{u}_i \phi_i(\mathbf{x})

(Explicit) Weighted sum of polynomial interpolants

Finite Elements

[1]

{\mathbf{u}}(\mathbf{x}) = (Z_L \circ \dotsc \circ Z_0)(\mathbf{x})

(Implicit) High-dim nonlinear feature learners

Multilayer Perceptron (MLP)

[2]

Cannot learn from data

Can learn from data

Large cost per simulation

Cheap evaluation after training

High-accuracy

Problem-specific

Robust

Up to \(0.1\%\) accuracy

[1] Math StackExchange — “Interpolation in Finite Element Method”  
[2] ResearchGate — “Structure of a Deep Neural Network” (figure)  

\(\text{Mesh ansatz}\)

\({u}(x)=\)

\(u(x) = \)

\(\text{Neural ansatz}\)

\(\text{Physics-based}\)

\(\text{Data-driven}\)

\(\text{Numerical}\)

\(\text{Simulation}\)

\(\text{Reduced Order}\)

\(\text{Modeling}\)

\(\text{Neural ROMs}\)

\(\text{Surrogate}\)

\(\text{Learning}\)

\(\text{Transformers}\)

\(\text{PINNs}\)

\(\text{Finite Elements}\)

\(\text{PCA/POD}\)

\(\text{Graph Networks}\)

Landscape of data-driven methods in computational physics

4

 Scalable transformer models for large-scale surrogate modeling

[1]

[3]

[2]

[1] CFD Direct / OpenFOAM — “Introduction to Computational Fluid Dynamics”  
[2] ResearchGate — “Schematic of a Vanilla Physics-Informed Neural Network” (figure)

[3] Kutz, J. N. — “Data-Driven Modeling & Dynamical Systems” (UW)  

Data-Driven Modeling

 

Scalable neural surrogates for PDEs and beyond!

Surrogate models learn PDE solution operator from data

5

\mathcal{L}(\boldsymbol{x}, t, \boldsymbol{u}; \boldsymbol{\mu}) = 0
\mathcal{G}: \boldsymbol{\mu} \mapsto \boldsymbol{u}

Training

\mathcal{G}_\theta \approx \mathcal{G}

Inference

\mathcal{G}_\theta

Large training cost is amortized over several evaluations

Model learns to predict \(\boldsymbol{u}\) over a distribution of \(\boldsymbol{\mu}\)
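As a sketch of how such a training loop could look (hypothetical: `model` stands for any \(\mathcal{G}_\theta\) architecture and `loader` yields \((\boldsymbol{\mu}, \boldsymbol{u})\) pairs; this is not the talk's actual pipeline):

import torch

def train_surrogate(model, loader, epochs=100, lr=1e-3):
    """Fit G_theta: mu -> u by minimizing relative L2 error over (mu, u) pairs."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for mu, u in loader:  # mu: PDE parameters / BCs; u: reference solution
            u_pred = model(mu)
            loss = torch.linalg.norm(u_pred - u) / torch.linalg.norm(u)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model  # afterwards, G_theta(mu) is a cheap forward pass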

Transformers [1] are state-of-the-art surrogate models

6

Message-passing on a dynamic all-to-all graph.

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017  
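In code, attention makes this all-to-all message passing explicit; a minimal single-head sketch (illustrative, not the benchmarked implementation):

import torch

def naive_attention(Q, K, V):
    """Self-attention over N mesh tokens. Q, K, V: [N, D].
    The [N, N] score matrix is one message per token pair."""
    scores = Q @ K.mT / K.shape[-1] ** 0.5   # [N, N]: O(N^2) compute and memory
    weights = torch.softmax(scores, dim=-1)  # every token attends to every other
    return weights @ V                       # aggregate the messages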


Quadratic \((\mathcal{O}(N^2))\) cost of attention limits scalability

7

Over \(20~\text{s}\) per gradient step on a mesh of one million points!

Goal: enable transformer models on large meshes.

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017  

\([1]\)

What are the limitations on communication patterns?

8

\Delta u = f
A\,\underline{u} = \underline{f} \qquad \text{(forward operator: sparse } A\text{)}
\underline{u} = A^{-1}\,\underline{f} \qquad \text{(solution operator: dense } A^{-1}\text{)}

Solution operator requires global communication.

Forward operator is implemented with sparse, structured communication.

Need principled strategy for reducing communication cost.
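A quick numerical check of this point (an illustrative sketch, assuming a 1D Laplacian with Dirichlet boundaries): the forward stencil is tridiagonal, yet its inverse is fully dense, so the solution operator couples every node to every other node.

import torch

N = 8
# 1D Laplacian stencil: each row only touches its neighbors (sparse forward operator)
A = 2 * torch.eye(N) - torch.diag(torch.ones(N - 1), 1) - torch.diag(torch.ones(N - 1), -1)
A_inv = torch.linalg.inv(A)
print((A != 0).sum().item())      # ~3N nonzeros: local communication
print((A_inv != 0).sum().item())  # N*N nonzeros: global communication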

Detour: finite elements

[1] ParticleInCell.com — “Finite Element Experiments in MATLAB” (2012)  

[1]

Are \(N \times N\) messages really necessary?

9

Smoothness implies redundancy in communication.


Method: group matching points into a cluster and communicate them together.

FLARE: Fast Low-rank Attention Routing Engine

10

Encoding: introduce \(M\) latent clusters to pool messages from matching tokens

\(M\) learned queries

Decoding: map pooled messages to matching output tokens

FLARE: Fast Low-rank Attention Routing Engine

11

\(\mathcal{O}(2MN) \ll \mathcal{O}(N^2)\)

\(\text{rank}(W_\text{encode}\cdot W_\text{decode}) \leq M\)

\(>200\times\) speedup

\(\text{(} M \text{ tokens)}\)

\(\text{Latent}\)

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017  

\([1]\)

PDE surrogate benchmark problems

Relative \(L_2\) error \( (\times 10^{-3})\) (lower is better)

12

Pipe

Darcy

Elasticity

LPBF

DrivAerML

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017 

[2] Jaegle et al. — "Perceiver IO: A General Architecture for Structured Inputs & Outputs", ICLR 2022

[3] Hao et al. — "GNOT: A General Neural Operator Transformer for Operator Learning", ICML 2023

[4] Wang et al. — "Latent Neural Operator", NeurIPS 2024

[5] Wu et al. — "Transolver: A Fast Transformer Solver for PDEs on General Geometries", ICML 2024

Elasticity benchmark problem

13

Pipe flow, Darcy flow benchmark problems

14

Laser powder bed fusion benchmark problem

15

Scaled dot-product implementation

16

import math
import torch.nn.functional as F

def flare_multihead_mixer_inefficient(Q, K, V):
    """Reference implementation that materializes the score matrix.
    Args - Q: [H M D] learned latent queries; K, V: [B H N D]
    Ret  - Y: [B H N D]
    """
    # [B H M N] scaled scores (1/sqrt(D) matches scaled_dot_product_attention below)
    scores = Q @ K.mT / math.sqrt(Q.shape[-1])
    # [B H M N]: encode weights, softmax over the N input tokens
    W_encode = F.softmax(scores, dim=-1)
    # [B H N M]: decode weights, softmax over the M latents
    W_decode = F.softmax(scores.mT, dim=-1)

    Z = W_encode @ V  # [B H M D] pooled latent messages
    Y = W_decode @ Z  # [B H N D] messages routed back to tokens

    return Y

def flare_multihead_mixer(Q, K, V):
    """Equivalent fused form: two attention calls, never materializing scores."""
    Z = F.scaled_dot_product_attention(Q, K, V)  # encode: latents attend to tokens
    Y = F.scaled_dot_product_attention(K, Q, Z)  # decode: tokens attend to latents
    return Y

\(\mathcal{O}(2MN)\) compute

\(\mathcal{O}(2MN)\) memory

FLARE learns surrogate on a million-point mesh!

Largest experiment on a single GPU!

17

[1]

[1] Ashton et al. — “DrivAerML: High-Fidelity CFD Dataset for Road-Car Aerodynamics” (arXiv:2408.11969, 2024)  

FLARE generalizes beyond PDE tasks

18

Pathfinder

ListOps

\texttt{INPUT:\, [MAX 4 3 [MIN 2 3 ] 1 0 [MEDIAN 1 5 8 9 2]] \,OUTPUT: 5}

\texttt{INPUT:\, [MAX 7 [MEDIAN 1 2 3 ] [MAX 9 2 2] [MIN 2 8]] \,OUTPUT: 9}

Image classification

Text sentiment analysis

[7]

[8]

[1]

[5] Choromanski et al. — "Rethinking Attention with Performers", ICLR 2021

[6] Tay et al. — "Long Range Arena: A Benchmark for Efficient Transformers", arXiv 2020

[7] Centric Consulting — “Sentiment Analysis: Way Beyond Polarity” (blog)  
[8] Krizhevsky — CIFAR dataset homepage  

[6]

Accuracy \((\%)\) (higher is better)

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017 

[2] Katharopoulos et al. — "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention", ICML 2020

[3] Wang et al. — "Linformer: Self-Attention with Linear Complexity", arXiv:2006.04768, 2020

[4] Qin et al. — "The Devil in Linear Transformer", arXiv:2210.10340, 2022

Takeaways

19

SOTA model for large-scale encoder attention.

Comprehensive ablations, spectral analysis.

Now implemented in NVIDIA PhysicsNeMo.

Ongoing work: decoder attention for transient problems.

Thank you

 

Questions?

Ongoing Work

 

Extend FLARE to transient problems

Motivation: preempt build failures in metal additive manufacturing

Laser Powder Bed Fusion (LPBF)

Dataset of 20k LPBF calculations

Goal: develop fast surrogate model to predict warpage during build

\rho C_p \frac{\partial T}{\partial t} = \nabla \cdot \left( k \nabla T(\mathbf{x}, t) \right) + Q(\mathbf{x}, t)
\nabla \cdot \boldsymbol{\sigma} = 0, \,\, \boldsymbol{\sigma} = C\boldsymbol{\varepsilon}_e

Governing equations

End results could be deployed as a valuable design tool for metal AM.

20

[1]

[2]

[1] Nature Scientific Data — High-resolution dataset (2025)  
[2] TechXplore — “Synergetic optimization reduces residual warpage in LPBF” (2022)  

Proposed work

Advance FLARE for enhanced surrogate modeling

21

Develop decoder version of FLARE

AIM 1(a): rank-adaptive FLARE

AIM 1(b): conditioning mechanism for FLARE

Advance FLARE for enhanced surrogate learning

Rank-adaptive FLARE for faster training

22

Complexity scales with latents (\(M\)): \(\mathcal{O}(2MN)\)

Accuracy increases with \(M\)

 Method: progressively increase latents (\(M\)) through training.

Challenge: Minimize loss spikes, training instabilities.
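One possible realization (a sketch of the proposed idea; `LatentQueries` and `grow_latents` are hypothetical names): keep the latent query bank as a parameter and append freshly initialized rows on a schedule, with small-norm initialization as one way to dampen loss spikes.

import torch
import torch.nn as nn

class LatentQueries(nn.Module):
    """Learned query bank for FLARE; M can grow during training."""
    def __init__(self, m_init, dim):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(m_init, dim) * 0.02)

    @torch.no_grad()
    def grow_latents(self, m_new):
        """Append m_new latents; small init keeps the new rows near-inactive at first."""
        extra = 1e-3 * torch.randn(m_new, self.queries.shape[1],
                                   device=self.queries.device, dtype=self.queries.dtype)
        self.queries = nn.Parameter(torch.cat([self.queries.data, extra], dim=0))

Note that the optimizer state must be rebuilt (or extended) after each growth step, since the parameter tensor is re-registered.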

Background on conditioning in transformers

\left( t,\, \mathbf{x},\, \mathbf{u}_{t-k},\, \cdots, \mathbf{u}_{t} \right) \mapsto \mathbf{u}_{t+1}

Token mixing [1] (\(\mathcal{O}(N^2)\))

Conditioning [1] (\(\mathcal{O}(N\cdot C)\))

\begin{bmatrix} {t} \\ \mathbf{x}\\ \mathbf{u}_{t-k:t} \end{bmatrix}

Token mixing

\times B
\begin{bmatrix} \mathbf{u}_{t+1} \end{bmatrix}
\begin{bmatrix} \mathbf{x}\\ \end{bmatrix}

Token mixing

\begin{bmatrix} \mathbf{u}_{t+1} \end{bmatrix}

Conditioning

\begin{bmatrix} t & \mathbf{u}_{t-k:t} \end{bmatrix}
\times B

23

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017  

Key idea: Modulate token-mixing with conditioning tokens

\begin{bmatrix} \mathbf{x}\\ \end{bmatrix}

Cross FLARE

\times B
\begin{bmatrix} \mathbf{u}_{t+1} \end{bmatrix}
\begin{bmatrix} t & \mathbf{u}_{t-k:t} \end{bmatrix}

24

Conditioning and cross-attention mechanism for FLARE

We propose to handle token mixing and conditioning in one unified block

\(\mathcal{O}(2MN + MC) \) complexity

Background on next-token prediction transformers [1]

y_t = \frac{\sum_{\tau = 1}^t\exp\left(q_t \cdot k_\tau \right) v_\tau}{\sum_{\tau = 1}^t \exp \left(q_t \cdot k_\tau \right)}

All previous key/value \(\{k_\tau, v_\tau \}_{\tau \leq t}\) must be cached on the GPU.

Major memory and latency bottleneck!

25

[1] Vaswani et al. — “Attention Is All You Need”, NeurIPS 2017  

Training algorithm (causal masking)

Inference algorithm (recurrence relation)

Dot-products need to be recomputed for every \(q_t\).

\(\mathcal{O}(N^2)\) complexity.
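A minimal single-head, unbatched sketch of the standard decode step (illustrative) shows both costs at once: the cache grows with \(t\), and every step recomputes dot products against all cached keys.

import torch

def decode_step(q_t, k_t, v_t, cache):
    """One autoregressive step with a standard KV cache.
    q_t, k_t, v_t: [D]; cache: (K, V) with K, V: [t-1, D]."""
    K = torch.cat([cache[0], k_t[None]], dim=0)             # cache grows: O(t) memory
    V = torch.cat([cache[1], v_t[None]], dim=0)
    w = torch.softmax(K @ q_t / K.shape[-1] ** 0.5, dim=0)  # O(t) dot products per step
    return w @ V, (K, V)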

Develop decoder version of FLARE

Linear time auto-regressive attention.

Fixed memory footprint (only store \(\mathcal{O}(M)\) cache).

Flexible latent capacity.

Advantages

Required components

Fused GPU kernels for training and inference.

Bespoke training algorithm for causal FLARE.

Extensive benchmarking and evaluation.

26

Z_t = \text{online\_softmax}(Z_{t-1}, k_t, v_t)\\ y_t = \text{softmax}(Q^T \cdot k_t)^T\cdot Z_t

Inference algorithm (recurrence rule)

Next-token prediction with FLARE
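A minimal sketch of this recurrence (proposed work; simplified to a running numerator/denominator per latent, omitting the max-shift stabilization a production online softmax needs):

import torch

def flare_decode_step(q_bank, k_t, v_t, state):
    """One causal FLARE step. q_bank: [M, D] learned queries; k_t, v_t: [D].
    state: (num [M, D], den [M]) -- running softmax numerator/denominator."""
    num, den = state
    w = torch.exp(q_bank @ k_t)             # [M]: unnormalized weight of token t per latent
    num = num + w[:, None] * v_t[None]      # O(M) state instead of an O(t) KV cache
    den = den + w
    Z_t = num / den[:, None]                # [M, D]: encode state after t tokens
    a = torch.softmax(q_bank @ k_t, dim=0)  # [M]: decode weights for token t
    return a @ Z_t, (num, den)              # y_t: [D]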

Thank you

 

Questions?

Machine learning dominates several fields of scientific discovery

Enhancing PDE solvers with ML

Landscape of ML for PDEs

Mesh ansatz

PDE-Based

Neural Ansatz

Data-driven

FEM, FVM, IGA, Spectral

Fourier Neural Operator

Neural Field

DeepONet

Physics Informed NNs

Convolutional NNs

Graph NNs

Adapted from Núñez, CEMRACS 2023

Neural ODEs

Universal Diff Eq

\dfrac{d\tilde{u}}{dt} = \tilde{\mathcal{L}}_p(\tilde{u}) + \text{NN}_\theta(\tilde{u})
\begin{cases} \dfrac{d u}{dt} = \mathcal{L}_p(u) + \mathcal{N}_p(u), & x\in\Omega\\ u|_{\partial\Omega} = g(t) \end{cases}

Reduced Order Modeling

Enhancing PDE solvers with ML

Newsflash: neural signal representations (almost) beat the curse of dimensionality

Orthogonal Functions vs. Deep Neural Networks

\(f = \tilde{f} + \mathcal{O}(h)\)

Orthogonal functions: \(\tilde{f}(x) = \sum_{i=1}^N f_i \phi_i(x)\)
  • \(N\) points; \(N \sim h^{-d/c}\): model size scales exponentially with dimension
  • \(\dfrac{d}{dx} \tilde{f} \sim \mathcal{O}(N^2)\) (exact)
  • \(\int_\Omega \tilde{f}\, dx \sim \mathcal{O}(N)\) (exact)

Deep neural networks: \(\tilde{f}(x) = (Z_L \circ \dotsc \circ Z_0)(x)\)
  • \(N\) parameters, \(M\) points; \(h \sim 1/N\) for shallow networks (Weinan, 2020): model size scales with signal complexity
  • \(\dfrac{d}{dx} \tilde{f} \sim \mathcal{O}(N)\) (exact, AD)
  • \(\int_\Omega \tilde{f}\, dx \sim \mathcal{O}(M)\) (approx)

Efficient transformer models

Triple Attention or Multi-linear attention

FEATURES

  • Considers N-tuples of tokens at a time.
  • More expressive than standard attention and linear attention.
  • Easily parallelizable across multiple GPUs.
  • Kernel-based interpretation.
  • As efficient and accurate as FLARE (SOTA).

DEMONSTRATIONS

  • Encoder transformer
    • PDE Surrogate modeling, Long-range arena
  • Decoder transformer
    • Next-token prediction/ language modeling

Triple Attention Scaling Study

Challenge: learn a PDE surrogate on 5-10M points on a multi-GPU cluster

Adaptive Layer Norm in the Diffusion Transformer allows for token mixing + time-conditioning in one go

This is only possible with a single token as conditioning vector, and won't work when you want to condition on a sequence.
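For reference, a minimal AdaLN sketch (following DiT; names are illustrative): a single conditioning vector is mapped to a per-channel scale and shift that modulate the normalized tokens, which is exactly why it cannot ingest a conditioning sequence.

import torch
import torch.nn as nn

class AdaLN(nn.Module):
    """Adaptive LayerNorm: modulate tokens with scale/shift from one conditioning vector."""
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)

    def forward(self, x, c):
        # x: [B, N, dim] tokens; c: [B, cond_dim] single conditioning vector (e.g. time)
        scale, shift = self.to_scale_shift(c).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale[:, None]) + shift[:, None]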

Linear transformers only store state \( S \in \mathbb{R}^{D \times D} \) but their performance is not on par with softmax attention

Linear transformers replace the softmax kernel with a feature map \(\phi(\cdot)\) such that

\mathrm{softmax}(QK^\top)\, V \approx \phi(Q)\,\big(\phi(K)^\top V\big)

This factorization allows causal attention to be computed recurrently:

S_t = S_{t-1} + \phi(k_t)^\top v_t, \qquad \mathbf{y}_t = \phi(q_t)\, S_t
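In code, the recurrence reads as follows (a sketch with \(\phi = \mathrm{elu} + 1\) as in Katharopoulos et al., including the normalizer term):

import torch
import torch.nn.functional as F

def linear_attention_step(q_t, k_t, v_t, S, z):
    """q_t, k_t: [D]; v_t: [Dv]; S: [D, Dv] running sum of phi(k) v^T; z: [D] normalizer."""
    phi_k = F.elu(k_t) + 1                  # positive feature map
    S = S + torch.outer(phi_k, v_t)         # constant-size state update
    z = z + phi_k
    phi_q = F.elu(q_t) + 1
    y_t = (phi_q @ S) / (phi_q @ z)         # [Dv]
    return y_t, S, z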

Chunkwise training for linear transformers

https://manifestai.com/articles/linear-transformers-are-faster/
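A chunkwise sketch (illustrative; the normalizer is omitted for brevity): within a chunk, attention is quadratic but small, while contributions from earlier chunks flow through the running state, giving roughly \(\mathcal{O}(N \cdot C)\) work for chunk size \(C\).

import torch
import torch.nn.functional as F

def chunkwise_linear_attention(Q, K, V, chunk=64):
    """Causal linear attention. Q, K: [N, D]; V: [N, Dv]."""
    fQ, fK = F.elu(Q) + 1, F.elu(K) + 1              # feature map phi = elu + 1
    S = torch.zeros(Q.shape[-1], V.shape[-1])        # inter-chunk state
    mask = torch.tril(torch.ones(chunk, chunk))      # intra-chunk causal mask
    out = []
    for s in range(0, Q.shape[0], chunk):
        q, k, v = fQ[s:s+chunk], fK[s:s+chunk], V[s:s+chunk]
        c = len(q)
        intra = ((q @ k.mT) * mask[:c, :c]) @ v      # within-chunk, causal
        inter = q @ S                                # messages from all past chunks
        out.append(intra + inter)
        S = S + k.mT @ v                             # fold this chunk into the state
    return torch.cat(out, dim=0)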

Premise: strong encoder model --> strong LLM

  • Next-token prediction model
    • FLARE, Triple Attention
    • Write CUDA kernels → get scaling plots
    • Test on language tasks
  • Extend FLARE
    • Allow the model to increase/decrease \(M\) during training
    • Create an efficient conditioning mechanism (time-series PDE problems)
  • Focus on new contributions; explain novelty relative to SOTA

Computer simulations are critical for industrial applications

Mesosphere \((1000\,\mathrm{km})\)

Wind farm \((10\,\mathrm{km})\)

Turbine \((100\,\mathrm{m})\)

Blade \((10\,\mathrm{m})\)

1


FLARE decoder recurrence

FLARE allows a tradeoff between accuracy and compute

Elasticity problem

Darcy problem

24

Low-rank structure allows for efficient eigenanalysis

Message-passing is fundamentally low-rank

25
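A quick check of this claim (illustrative): build the composite mixing matrix from random queries and keys and verify that its numerical rank is bounded by the number of latents \(M\).

import torch

M, N, D = 16, 256, 32
Q, K = torch.randn(M, D), torch.randn(N, D)
W_encode = torch.softmax(Q @ K.mT, dim=-1)  # [M, N]
W_decode = torch.softmax(K @ Q.mT, dim=-1)  # [N, M]
W = W_decode @ W_encode                     # [N, N] composite token-mixing matrix
print(torch.linalg.matrix_rank(W).item())   # <= M, far below N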

Experiment: 2D viscous Burgers problem \((\mathit{Re} = 1000)\)

\frac{\partial \boldsymbol{u}}{\partial t} + \boldsymbol{u} \cdot \boldsymbol{\nabla}\boldsymbol{u} = \nu \Delta \boldsymbol{u}

13

\(\text{CAE-ROM}\) [1]

\(\text{SNFL-ROM (ours)}\)

\(\text{SNFW-ROM (ours)}\)

\(\text{Relative error }\)

[1] Lee & Carlberg — Nonlinear manifold ROM via CNN autoencoders (JCP 2020)

\([1]\)
