Foundation Neural-Network Quantum State trained on the two-dimensional $J_1$-$J_2$-$J_3$ Heisenberg model on a $10\times 10$ square lattice. The system is described by the following Hamiltonian (with periodic boundary conditions):

$$
\hat{H} = J_1\!\!\sum_{\langle {\boldsymbol{r}},{\boldsymbol{r'}} \rangle} \hat{\boldsymbol{S}}_{\boldsymbol{r}}\cdot\hat{\boldsymbol{S}}_{\boldsymbol{r'}} + J_2 \!\!\!\!\sum_{\langle \langle {\boldsymbol{r}},{\boldsymbol{r'}} \rangle \rangle} \!\!\!\hat{\boldsymbol{S}}_{\boldsymbol{r}}\cdot\hat{\boldsymbol{S}}_{\boldsymbol{r'}} + J_3 \!\!\!\!\!\!\sum_{\langle \langle \langle {\boldsymbol{r}},{\boldsymbol{r'}} \rangle \rangle \rangle} \!\!\!\!\!\hat{\boldsymbol{S}}_{\boldsymbol{r}}\cdot \hat{\boldsymbol{S}}_{\boldsymbol{r'}}
$$
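The three sums run over first-, second-, and third-neighbor pairs, which on the square lattice sit at distances 1, sqrt(2), and 2. As a rough illustration (not the code used to train the model), the sketch below enumerates these bonds on the periodic 10x10 lattice in plain Python. The row-major site index `x * L + y` and the helper name `neighbor_bonds` are assumptions made for this example only.

```python
from itertools import product


def neighbor_bonds(L):
    """Group the bonds of an L x L periodic square lattice by neighbor order.

    Hypothetical helper for illustration; sites are indexed as x * L + y.
    """
    # Two displacements per order are enough to visit every bond once.
    offsets = {
        1: [(1, 0), (0, 1)],    # distance 1       -> J_1 term
        2: [(1, 1), (1, -1)],   # distance sqrt(2) -> J_2 term
        3: [(2, 0), (0, 2)],    # distance 2       -> J_3 term
    }
    bonds = {}
    for order, displacements in offsets.items():
        found = set()
        for x, y in product(range(L), repeat=2):
            for dx, dy in displacements:
                i = x * L + y
                j = ((x + dx) % L) * L + (y + dy) % L  # periodic wrap-around
                found.add((min(i, j), max(i, j)))
        bonds[order] = sorted(found)
    return bonds


bonds = neighbor_bonds(10)
counts = {order: len(pairs) for order, pairs in bonds.items()}
print(counts)
```

Each neighbor order contributes 2N = 200 bonds for N = 100 sites, so the three coupling terms act on the same number of pairs.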

The architecture has been trained on $R=4000$ systems with couplings equispaced in the intervals $J_2 \in [0.0, 1.0]$ and $J_3 \in [0.0, 0.6]$, using a total batch size of $M=16000$ samples.

How to Get Started with the Model

Use the code below to get started with the model. In particular, we sample the architecture for fixed values of $J_2$ and $J_3$ using NetKet.

from functools import partial
import numpy as np

import jax
import jax.numpy as jnp
import netket as nk
from huggingface_hub import hf_hub_download

import flax
from flax.training import checkpoints

flax.config.update('flax_use_orbax_checkpointing', False)

lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=3)

J2 = 0.5
J3 = 0.0

assert J2 >= 0. and J2 <= 1.0 #* the model has been trained on this interval
assert J3 >= 0. and J3 <= 0.6 #* the model has been trained on this interval

from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2j3_square_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert, 
                                    graph=lattice, 
                                    J=[1.0, J2, J3], 
                                    sign_rule=[False, False, False]).to_jax_operator() # No Marshall sign rule

sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16000,
                                        sweep_size=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
coups = np.array([J2, J3])
vstate = nk.vqs.MCState(sampler=sampler, 
                        apply_fun=partial(wf.__call__, coups=coups), 
                        sampler_seed=subkey,
                        n_samples=16000, 
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16000)

# Overwrite samples with already thermalized ones
path = hf_hub_download(repo_id="nqs-models/j1j2j3_square_fnqs", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)

import time
# Sample the model
for _ in range(100):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    
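    # NetKet's Heisenberg operator is written with Pauli matrices, so dividing
    # by 4 converts the result to spin-1/2 units; n_nodes gives energy per site.
    print("Mean: ", E.mean.real / lattice.n_nodes / 4, "\t time=", time.time()-start)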
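The factor of 4 in the printed energy comes from the operator convention: the Hamiltonian above is written with spin-1/2 operators, while NetKet's `Heisenberg` is built from Pauli matrices, and each spin operator is half the corresponding Pauli matrix. The following minimal NumPy sketch checks this on a single two-site bond; it is an independent illustration and does not use the model or NetKet.

```python
import numpy as np

# Pauli matrices
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.array([[1, 0], [0, -1]], dtype=complex)

# One bond written with Pauli matrices (sigma . sigma)
pauli_bond = sum(np.kron(s, s) for s in (sx, sy, sz))
# The same bond written with spin-1/2 operators (S = sigma / 2)
spin_bond = sum(np.kron(s / 2, s / 2) for s in (sx, sy, sz))

# Lowest eigenvalue of each form: the singlet energy of the bond
e_pauli = np.linalg.eigvalsh(pauli_bond).min()
e_spin = np.linalg.eigvalsh(spin_bond).min()

print(e_pauli, e_spin, e_pauli / e_spin)
```

The two matrices differ by exactly a factor of 4, so energies computed in the Pauli convention are divided by 4 to compare with values quoted for the spin-operator Hamiltonian.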

Extract hidden representation

The hidden representation associated to the input batch of configurations can be extracted as:

wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2j3_square_fnqs", trust_remote_code=True, return_z=True)

# coups = np.array([J2, J3]), as defined in the sampling example above
z = wf(wf.params, samples, coups=coups)

Training Hyperparameters

Number of layers: 8
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12
Patch size: 2x2

Total number of parameters: 434,904
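To make the patch size concrete, the sketch below splits one 10x10 spin configuration into non-overlapping 2x2 patches, giving 25 patches of 4 spins each. The row-major site ordering and the ordering of the patches are assumptions made for illustration; the model's own code may arrange sites differently. The head dimension follows from the values above as 72 / 12 = 6.

```python
import numpy as np

L, b = 10, 2                      # lattice side and patch side
embed_dim, n_heads = 72, 12       # values from the hyperparameter list

rng = np.random.default_rng(0)
config = rng.choice([-1, 1], size=L * L)   # one random spin configuration

# Assumed layout: sites stored row-major, patches read left to right, top to bottom
patches = (
    config.reshape(L // b, b, L // b, b)   # (patch row, row in patch, patch col, col in patch)
    .transpose(0, 2, 1, 3)                 # (patch row, patch col, row in patch, col in patch)
    .reshape(-1, b * b)                    # (number of patches, spins per patch)
)

head_dim = embed_dim // n_heads

print(patches.shape, head_dim)
```

Under these assumptions the Transformer processes a sequence of 25 tokens per configuration rather than 100 individual spins.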

Model Card Contact

Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])
