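Foundation Neural-Network Quantum State (FNQS) trained on the two-dimensional $J_1$-$J_2$-$J_3$ Heisenberg model on a square lattice. The system is described by the following Hamiltonian (with periodic boundary conditions):
$$\hat{H} = J_1 \sum_{\langle i,j \rangle} \hat{\boldsymbol{S}}_i \cdot \hat{\boldsymbol{S}}_j + J_2 \sum_{\langle\langle i,j \rangle\rangle} \hat{\boldsymbol{S}}_i \cdot \hat{\boldsymbol{S}}_j + J_3 \sum_{\langle\langle\langle i,j \rangle\rangle\rangle} \hat{\boldsymbol{S}}_i \cdot \hat{\boldsymbol{S}}_j$$
where the three sums run over first-, second-, and third-nearest-neighbor pairs, respectively, and $J_1 = 1$ sets the energy scale.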
The architecture has been trained on systems equispaced in the intervals $J_2 \in [0, 1]$ and $J_3 \in [0, 0.6]$ (in units of $J_1$), using a total batch size of samples.
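As an illustration of what equispaced couplings within these ranges could look like, the sketch below builds a regular grid of (J2, J3) pairs in pure Python. The grid resolution (`N_J2`, `N_J3`) and the helper names `equispaced` and `in_training_range` are hypothetical; the actual number and layout of the training systems are not stated here, so this is only one possible reading.

```python
def equispaced(lo, hi, n):
    """Return n equally spaced values from lo to hi, endpoints included."""
    if n < 2:
        raise ValueError("need at least two points")
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]


def in_training_range(j2, j3):
    """Check a coupling pair against the ranges asserted in the example code."""
    return 0.0 <= j2 <= 1.0 and 0.0 <= j3 <= 0.6


# Hypothetical grid resolution, chosen only for illustration.
N_J2, N_J3 = 11, 7

j2_values = equispaced(0.0, 1.0, N_J2)
j3_values = equispaced(0.0, 0.6, N_J3)

# Cartesian product of the two axes: one (J2, J3) pair per system.
couplings = [(j2, j3) for j2 in j2_values for j3 in j3_values]

print(len(couplings), couplings[0], couplings[-1])
```

A helper like `in_training_range` can be used to validate user-supplied couplings before querying the model, since values outside the training intervals are extrapolations.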
How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the architecture for fixed values of $J_2$ and $J_3$ using NetKet.
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
from huggingface_hub import hf_hub_download
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=3)
J2 = 0.5
J3 = 0.0
assert J2 >= 0. and J2 <= 1.0 #* the model has been trained on this interval
assert J3 >= 0. and J3 <= 0.6 #* the model has been trained on this interval
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2j3_square_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(
    hilbert=hilbert,
    graph=lattice,
    J=[1.0, J2, J3],  # J1 = 1, followed by J2 and J3
    sign_rule=[False, False, False],  # No Marshall sign rule
).to_jax_operator()
sampler = nk.sampler.MetropolisExchange(
    hilbert=hilbert,
    graph=lattice,
    d_max=2,
    n_chains=16000,
    sweep_size=lattice.n_nodes,
)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
coups = np.array([J2, J3])
vstate = nk.vqs.MCState(
    sampler=sampler,
    apply_fun=partial(wf.__call__, coups=coups),  # fix the couplings (J2, J3)
    sampler_seed=subkey,
    n_samples=16000,
    n_discard_per_chain=0,
    variables=wf.params,
    chunk_size=16000,
)
# Overwrite samples with already thermalized ones
path = hf_hub_download(repo_id="nqs-models/j1j2j3_square_fnqs", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)
import time
# Sample the model
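for _ in range(100):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    # NetKet's Heisenberg operator uses Pauli matrices, hence the factor 1/4
    # to obtain the energy per site in spin-1/2 units.
    print("Mean: ", E.mean.real / lattice.n_nodes / 4, "\t time=", time.time() - start)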
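The division by `lattice.n_nodes` and by 4 in the loop above turns the total energy into an energy per site in spin-1/2 units: NetKet's `Heisenberg` operator is written with Pauli matrices, and σ_i · σ_j = 4 S_i · S_j. The sketch below isolates that conversion and adds a simple summary over repeated estimates. The helper names `energy_per_site` and `summarize` are hypothetical, and the input numbers are placeholders, not results obtained from the model.

```python
import math
import statistics


def energy_per_site(total_energy, n_sites, pauli_convention=True):
    """Convert a total energy into an energy per site in spin-1/2 units.

    With Pauli matrices, sigma_i . sigma_j = 4 S_i . S_j, hence the factor 4.
    """
    scale = 4.0 if pauli_convention else 1.0
    return total_energy / n_sites / scale


def summarize(values):
    """Return the mean and a naive standard error of a list of estimates.

    Successive Monte Carlo estimates are correlated, so this naive standard
    error tends to underestimate the true uncertainty.
    """
    mean = statistics.fmean(values)
    stderr = statistics.stdev(values) / math.sqrt(len(values))
    return mean, stderr


# Placeholder total energies (Pauli convention) for a 100-site lattice.
totals = [-198.0, -200.0, -202.0]
per_site = [energy_per_site(e, n_sites=100) for e in totals]
mean, stderr = summarize(per_site)
print(f"E/N = {mean:.4f} +/- {stderr:.4f}")
```

In the sampling loop, the values of `E.mean.real` could be appended to a list and passed through the same two helpers once the loop finishes.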
Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as follows:
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2j3_square_fnqs", trust_remote_code=True, return_z=True)
# `samples` and `coups = np.array([J2, J3])` are defined in the sampling example above
z = wf(wf.params, samples, coups)
Training Hyperparameters
Number of layers: 8
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12
Patch size: 2x2
Total number of parameters: 434,904
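A few derived quantities follow directly from the hyperparameters listed above. The sketch below computes them under stated assumptions: the lattice side (10) is taken from the example code, the patches are assumed to be non-overlapping and to tile the lattice exactly, and the embedding is assumed to be split evenly across heads. Variable names are illustrative only.

```python
# Values copied from the hyperparameter list.
embedding_dim = 72
hidden_dim = 288
n_heads = 12
patch_height, patch_width = 2, 2

# Lattice side taken from the example code (length=10); an assumption here.
lattice_length = 10

# Assumes non-overlapping patches that tile the lattice exactly.
assert lattice_length % patch_height == 0 and lattice_length % patch_width == 0
n_patches = (lattice_length // patch_height) * (lattice_length // patch_width)

# Assumes the embedding is split evenly across attention heads.
assert embedding_dim % n_heads == 0
head_dim = embedding_dim // n_heads

# Ratio between the hidden and embedding widths.
width_ratio = hidden_dim / embedding_dim

print(f"patches per configuration: {n_patches}")
print(f"dimension per head: {head_dim}")
print(f"hidden/embedding ratio: {width_ratio}")
```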
Model Card Contact
Riccardo Rende
Luciano Loris Viteritti