Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain with $L=100$ sites. The system is described by the following Hamiltonian (with periodic boundary conditions):

$$\hat{H} = -J\sum_{i=1}^N \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^N \hat{S}_i^x \ ,$$

where $\hat{S}_i^x$ and $\hat{S}_i^z$ are spin-$1/2$ operators on site $i$.
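
In terms of Pauli matrices, $\hat{S}_i^x = \tfrac{1}{2}\sigma_i^x$ and $\hat{S}_i^z = \tfrac{1}{2}\sigma_i^z$.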

The model has been trained on $R=6000$ different values of the field $h$, equispaced in the interval $h \in [0.8, 1.2]$, using a total batch size of $M=12000$ samples.
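
The training script itself is not included in this card. As a minimal sketch, a grid of this kind could be built as follows (np.linspace and the per-field sample split are assumptions, not the authors' code):

import numpy as np

R = 6000                           # number of field values
M = 12000                          # total batch size
fields = np.linspace(0.8, 1.2, R)  # equispaced values of h in [0.8, 1.2]
samples_per_field = M // R         # 2 samples per field value (assumption)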

The computation has been distributed over 4 A100-64GB GPUs for a few hours.

How to Get Started with the Model

Use the code below to get started with the model. In particular, we sample the model for a fixed value of the external field $h$ using NetKet.

from functools import partial
import numpy as np

import jax
import jax.numpy as jnp
import netket as nk

import flax
from flax.training import checkpoints

flax.config.update('flax_use_orbax_checkpointing', False)

lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)

revision = "main"
h = 1.0  # fix the value of the external field

assert 0.8 <= h <= 1.2  # the model has been trained on this interval

from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)

# Single-site Metropolis updates; one sweep proposes lattice.n_nodes local moves per chain
action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert, 
                                       rule=action, 
                                       n_chains=12000, 
                                       n_sweeps=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
# Variational Monte Carlo state wrapping the pretrained network;
# the field h is fed to the model through its coups argument
vstate = nk.vqs.MCState(sampler=sampler, 
                        apply_fun=partial(wf.__call__, coups=h), 
                        sampler_seed=subkey,
                        n_samples=12000, 
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=12000)

# start from thermalized configurations
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)
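
# If needed, the refreshed configurations can be written back to disk with the
# same Flax checkpoint utility (illustrative sketch; directory and step are
# placeholders, not part of the original card):
# checkpoints.save_checkpoint("./", target=np.asarray(vstate.sampler_state.σ),
#                             step=0, prefix="spins", overwrite=True)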

import time
# Sample the model and measure the energy
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)  # energy estimate on the current batch of samples
    vstate.sample()                 # draw fresh samples for the next iteration

    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)

The time per sweep is 3.5 s, evaluated on a single A100-40GB GPU.
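
Each call to expect returns a NetKet statistics object, so Monte Carlo error estimates come for free. For example, the loop above can additionally report (a minimal sketch using standard fields of the returned stats object):

print("Energy per site:", E.mean.real / lattice.n_nodes,
      "\t error:", E.error_of_mean / lattice.n_nodes)
print("Variance:", E.variance, "\t R_hat:", E.R_hat)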

Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as follows:

wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True)

z = wf(wf.params, samples, h)
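
The exact layout of z is not documented here; assuming a (batch, tokens, embedding) layout, it can be inspected and pooled into per-configuration features:

print(z.shape)             # assumed layout: (n_samples, n_patches, embedding_dim)
z_pooled = z.mean(axis=1)  # illustrative per-configuration feature vector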

Training Hyperparameters

Number of layers: 6
Embedding dimension: 72
Hidden dimension: 144
Number of heads: 12
Patch size: 4

Total number of parameters: 198288
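
These numbers can be cross-checked against the loaded model. Continuing from the script above (the 25-token layout is inferred from L=100 and patch size 4, not stated explicitly in this card):

# Each configuration of 100 spins is split into patches of 4 sites,
# giving 100 // 4 = 25 tokens, each embedded into 72 dimensions.
n_patches = 100 // 4
print("tokens per configuration:", n_patches)

# The reported total matches the loaded weights.
assert nk.jax.tree_size(wf.params) == 198288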

Model Card Contact

Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])
