Pretrained Vision Transformer Neural Quantum State for the J1-J2 Heisenberg model on a 10x10 square lattice. The frustration ratio is set to J2/J1 = 0.5.
| Revision | Variational energy | Time per sweep | Description |
|---|---|---|---|
| main | -0.497505103 | 41s | Plain ViT with translation invariance among patches |
| symm_t | -0.49760546 | 166s | ViT with translational symmetry |
| symm_trxy_ising | -0.497676335 | 3317s | ViT with translational, point-group and Sz-inversion symmetries |
The time per sweep is evaluated on a single A100-40GB GPU.
The architecture has been trained by distributing the computation over 40 A100-64GB GPUs for about four days.
Citation
https://www.nature.com/articles/s42005-024-01732-4
How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the model using NetKet.
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints

flax.config.update('flax_use_orbax_checkpointing', False)

# Load the model from HuggingFace
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)

N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# 10x10 square lattice with periodic boundary conditions and next-nearest-neighbour bonds
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)

# Spin-1/2 Hilbert space restricted to the zero-magnetization sector
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)

# J1-J2 Heisenberg Hamiltonian with J2/J1 = 0.5, without the Marshall sign rule
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, 0.5],
                                     sign_rule=[False, False]).to_jax_operator()

# Metropolis sampler exchanging pairs of spins, so the total magnetization is conserved
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16384,
                                        sweep_size=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)

vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=wf.__call__,
                        sampler_seed=subkey,
                        n_samples=16384,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16384)

# Overwrite samples with already thermalized ones
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)

# Sample the model and print the energy per site
for _ in range(10):
    E = vstate.expect(hamiltonian)
    print("Mean: ", E.mean.real / lattice.n_nodes / 4)
    vstate.sample()
The expected output is:
Number of parameters = 434760
Mean: -0.4975034481394982
Mean: -0.4975697817150899
Mean: -0.49753878662981793
Mean: -0.49749150331671876
Mean: -0.4975093308123018
Mean: -0.49755810175173776
Mean: -0.49753726455462444
Mean: -0.49748956161946795
Mean: -0.497479875901942
Mean: -0.49752966071413424
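The statistics object returned by expect also carries an error estimate; as a minimal sketch, the Monte Carlo error of the mean can be printed alongside the energy per site (reusing vstate and hamiltonian from the snippet above):
E = vstate.expect(hamiltonian)
# Energy per site and its Monte Carlo error, with the same normalization as above
print("Mean:  ", E.mean.real / lattice.n_nodes / 4)
print("Error: ", E.error_of_mean / lattice.n_nodes / 4)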
The fully translationally invariant wavefunction can also be downloaded using:
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_t")
Use revision="symm_trxy_ising" for a wavefunction that also includes the point-group and Sz-inversion symmetries, as in the snippet below.
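The fully symmetrized wavefunction is loaded with the same call; only the revision string changes:
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_trxy_ising")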
Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as:
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples)
Starting from the vector z, a fully connected network can be trained to fine-tune the model at a different value of the ratio J2/J1 (a minimal example is sketched below). See https://doi.org/10.1103/PhysRevResearch.6.023057 for more information.
Note: the hidden representation is well defined only for the non-symmetrized model.
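As an illustration, here is a minimal sketch of such a fine-tuning head written with Flax; the class name OutputHead, the layer sizes and the complex read-out are assumptions made for the example, not the architecture used in the reference above:
import flax.linen as nn

class OutputHead(nn.Module):
    # Hypothetical fully connected head acting on the hidden representation z
    hidden_features: int = 288

    @nn.compact
    def __call__(self, z):
        x = nn.Dense(self.hidden_features)(z)
        x = nn.gelu(x)
        out = nn.Dense(2)(x)
        # Interpret the two outputs as log-modulus and phase of the amplitude
        return out[..., 0] + 1j * out[..., 1]

# Initialize and apply the head on the hidden representation z extracted above
head = OutputHead()
head_params = head.init(jax.random.PRNGKey(1), z)
log_psi = head.apply(head_params, z)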
Training Hyperparameters
Number of layers: 8
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12
Total number of parameters: 434760
Model Card Contact
Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])