jVMC: Versatile and performant variational Monte Carlo leveraging automated differentiation and GPU acceleration

The introduction of Neural Quantum States (NQS) has recently given a new twist to variational Monte Carlo (VMC). The ability to systematically reduce the bias of the wave function ansatz renders the approach widely applicable. However, performant implementations are crucial to reach the numerical state of the art. Here, we present a Python codebase that supports arbitrary NQS architectures and model Hamiltonians. Additionally leveraging automatic differentiation, just-in-time compilation to accelerators, and distributed computing, it is designed to facilitate the composition of efficient NQS algorithms.


Introduction
The numerical simulation of strongly correlated quantum many-body systems constitutes a major challenge in computational physics. Even when fully exploiting modern supercomputers, direct solutions that involve the complete state vector are restricted to a few dozen degrees of freedom, because the dimension of the underlying Hilbert space grows exponentially with system size [1]. Therefore, suited approximate methods are required in order to address large or infinite system sizes, which are typically of interest when studying universal behavior. Particularly desirable are controlled methods, which can become numerically exact by tuning a control parameter.
During the past decades the development of a number of powerful numerical methods enabled substantial advances of our understanding of quantum many-body systems in and out of equilibrium. Prime examples are tensor network algorithms for low-dimensional systems [2][3][4], dynamical mean field theory in high dimensions [5][6][7][8], and quantum Monte Carlo [9] for equilibrium properties in the absence of a sign problem. However, despite the broad applicability of these methods, various physical problems remain elusive for state of the art numerical approaches, among which, e.g., systems in two or three spatial dimensions, long time dynamics, and frustrated systems. Since the limitations of the established methods in these regards are well understood, novel paths need to be explored to address the existing challenges.
The recent proposal of neural quantum states (NQS) as a new ansatz class for variational quantum Monte Carlo opened new perspectives and holds the potential to overcome existing limitations [10]. With their expressive power and their applicability independent of specific spatial structures NQS are a suited basis to devise a new family of broadly applicable algorithms. Their utility has been outlined in a number of fundamental works. It has been shown that NQS can be used to describe ground states of critical systems in two dimensions [11] or states with (chiral) topological order [12][13][14] and their suitability to study frustrated magnets is under investigation [15][16][17][18][19]. Moreover, various approaches have been developed for the simulation of non-equilibrium dynamics in closed [10,[20][21][22][23][24][25][26] and open quantum systems [27][28][29]. Other directions include the incorporation of NQS in quantum chemical simulations [30][31][32], quantum state tomography [33,34], and the simulation of quantum computation [35].
A key practical aspect that NQS algorithms have in common with many other deep learning applications is the possibility to harness cutting edge supercomputing resources at large scales; in fact, exploiting the intrinsic parallelism is crucial to achieve state-of-the-art results. In this paper we introduce the jVMC Python codebase that provides efficient implementations of the typical tasks that will be at the core of any algorithm involving NQS, namely composing and evaluating arbitrary network architectures, sampling, and evaluating quantum operators. The implementations and the interfaces are designed to enable the straightforward composition of custom algorithms and the utilization of distributed compute clusters, ideally with GPU or TPU accelerators. For this purpose the code is based on the JAX library [36], that provides functionality for automatic differentiation and optimized just-in-time compilation targeted at the available compute resources. Distributed computing across multiple nodes is enabled by incorporating the Message Passing Interface (MPI) through the mpi4py library [37,38].
The repository containing the jVMC source code and a detailed documentation are available online [39,40]. The codebase and all dependencies can be installed as a Python package via >>> pip install jVMC

Variational Monte Carlo algorithms
In this section we outline common Variational Monte Carlo (VMC) algorithms in order to highlight what are the core building blocks for which the jVMC codebase provides solutions. For this purpose we discuss how to find ground states and low-lying excited states and how to simulate real time dynamics of isolated or open systems using a timedependent variational principle (TDVP).
For simplicity we will consider spin-1/2 degrees of freedom and corresponding computational basis configurations s = (s 1 , . . . , s N ) with s i =↑, ↓. The generalization to higherdimensional local Hilbert spaces is straightforward.

Quantum expectation values
The starting point of variational Monte Carlo is a variational ansatz for the coefficients ψ θ (s) with variational parameters θ = (θ 1 , . . . , θ K ), such that without further assumptions the generally unnormalized wave function reads In the following we assume that θ k ∈ R. Given a wave function in this form the expectation value of any quantum operatorÔ can be rewritten as Since p θ (s) ≡ ψ θ (s) 2 ⟨ψ(θ) ψ(θ)⟩ constitutes a probability distribution on the computational basis configurations, we can write where we introduced the local estimator and the notation ⟪⋅⟫ θ for an expectation value with respect to p θ (s). Since typical quantum operators are sparse in the computational basis, the sum ∑ s ′ in Eq. (4) can be evaluated efficiently. By contrast, the sum over all basis states, ∑ s , in Eq. (3) becomes unfeasible for large system sizes N , because the Hilbert space dimension grows exponentially with N . Therefore, one has to resort to Monte Carlo (MC) sampling of p θ (s) to efficiently estimate these expectation values; for this purpose, only the functional form of ψ θ (s) must allow for efficient evaluation [41]. Then samples {s (1) , . . . , s (N MC ) } s∼p θ (s) can be obtained using, e.g., the Metropolis-Hastings algorithm, and the expectation value can be approximated by the empirical mean

Finding low-energy states
Given a variational ansatz ψ θ (s) the ground state of a many-body HamiltonianĤ can be approximated variationally using gradient-based optimization. The goal is to find the minimal energy expectation value In this expression S k,k ′ = Re S k,k ′ is the real part of the quantum Fisher matrix The quantum Fisher matrix is the metric tensor of the Fubini-Study metric -a natural metric on the projective Hilbert space [43]. It provides information about the local geometry of the variational manifold, which is exploited by adjusting the gradient step accordingly in Eq. (10). However, this approach requires the inversion of the S-matrix, which imposes constraints on the feasible number of variational parameters that determines the matrix dimension. Moreover, the quantum Fisher matrix is often (approximately) rank-deficient, which means that changing the parameters along some directions in the variational manifold leaves the physical state invariant. This ill-conditionedness demands careful regularization when inverting S for the SR update step in Eq. (10), see Section 2.5 for details. Besides ground states, it is also possible to address low-lying excited states with VMC techniques. In the method introduced in Ref. [44] the first step is to perform a ground state search resulting in a set of optimal parameters θ * and a variational ground state ϕ(θ * )⟩. Subsequently, the ansatz for the excited state is projected onto the subspace orthogonal to the ground state by defining The excited state search is then performed with alternating steps, first computing λ and then performing one SR step with the ansatz (12).

Unitary dynamics
For the simulation of real time dynamics the goal is to obtain equations of motion for the time-dependent variational parameters θ(t) that yield an approximate solution of the Schrödinger equation whereĤ denotes the system Hamiltonian. While alternative approaches have been explored recently [22], there are two long established ways to formulate a TDVP for the Schrödinger equation -(i) based on the principle of least action [45,46], and (ii) based on the maximization of overlaps or similar measures of proximity in the underlying Hilbert space [10,46,47]. These two derivations yield the identical result in the special case, where the parametrized wave function ψ θ (s) ≡ ψ ϑ (s) is holomorphic, i.e., fulfills Cauchy-Riemann equations as function of complex variational parameters ϑ l = θ 2l + iθ 2l+1 ∈ C [46]. The corresponding TDVP equation is a first order differential equation for the variational parameters: Here, the force vector is defined as and S l,l ′ is the quantum Fisher matrix as in Eq. (11). In this case O ϑ l (s) = ∂ ϑ l log ψ ϑ (s) denotes the complex derivative of the logarithmic wave function coefficients.
The TDVP equation (15) defines a Hamiltonian dynamics of the variational parameters, which conserves the quantum expectation value of energy [45]. Notice, however, that other constants of motion of the quantum dynamics are in general not conserved under this TDVP.
When releasing the requirement of a holomorphic ansatz, the different approaches to the TDVP yield closely related, but different results. The principle of least action, (i), results in while the minimization of some distance measure, (ii), gives While both equations correspond to a valid TDVP, only Eq. (17) preserves the symplectic structure of the variational manifold and thereby the conservation of energy under time evolution; by contrast, following Eq. (18), energy is not necessarily conserved. Moreover, solving Eq. (15) for complex parameters ϑ l = θ 2l + iθ 2l+1 is equivalent to solving Eq. (17) for the corresponding real parameters θ k .
For a unified notation we can write ] is typically ill-conditioned, meaning that a suited regularization has to be applied in order to obtain a stable solution, see Section 2.5.
As an aside, notice that for the purpose of VMC we consider wave functions ψ θ (s) with no further constraints on θ. Therefore, the solution of Eq. (19) is always part of the variational manifold we consider. A prominent example of an ansatz class with additional constraints on the variational parameters θ are the matrix product states. In that case an explicit projection onto the variational manifold has to be included as part of the TDVP [47,48].
Stochastic Reconfiguration for ground state search and the TDVP for unitary time evolution are closely related: Since ∂ θ k E(θ) = 2Re[F k ], the SR update step (10) corresponds to an imaginary time step according to the TDVP equation. Therefore, the computational steps required for both are very similar and they are summarized in the pseudocode of Algorithm 1. We will see in the following section that also variational simulation of the dynamics of an open quantum system shares the same building blocks. The jVMC codebase provides efficient implementations of these fundamental ingredients for NQS algorithms.

Markovian dissipative dynamics
When taking into account the coupling to an environment, quantum states become mixed and one has to resort to the density matrix formalism for a theoretical description. Assuming a Markovian environment, the time evolution of the density matrix is given by the Lindbladian extension of the von-Neumann equation Here, [⋅, ⋅] and {⋅, ⋅} denote the commutator and the anti-commutator, respectively. Moreover the operatorsL i are the so-called jump operators, which incorporate the action of the environment and drive the system towards a steady-state.
The TDVP for unitary dynamics can be generalized to variational approaches to solve the Lindblad equation. One possibility is based on representing the density matrix in a purified form [27]. While implementing the purification approach is also within the scope of the jVMC codebase, we focus in the following on an alternative method that relies on the Positive Operator Valued Measurements (POVM)-formalism [28,29,34,49,50] for a purely probabilistic formulation of quantum mechanics.
A POVM is a set of positive semi-definite operators {M a } such that ∑ aM a = 1. Given a POVM {M a } for the local Hilbert space, a many-body POVM can be constructed aŝ M a =M a 1 ⊗ . . . ⊗M a N [34] and the outcome a = (a 1 , . . . , a N ) is associated with the probability P a = tr ρM a .
A POVM is called minimal and informationally complete if the number of POVM elements equals the number of free parameters of the density matrix. In that case the POVM forms an operator basis on the considered Hilbert space and the relation in Eq.
with the overlap matrix T ab = tr M aM b . Hence, a minimal informationally complete POVM yields an equivalent purely probabilistic formulation of quantum mechanics in which the dissipative dynamics is governed by a Master equation for the probability distribution P a ,Ṗ a = L ab P b .
In this expression, the linear operator L ab contains the physical information equivalent to the Lindblad equation (20), see Ref. [29] for details. For the variational approach, an ansatz P a θ is chosen in order to find an efficient representation of the POVM distribution, and a TDVP yields a first-order differential equation for the variational parameters θ, see Ref. [29].
Here, S denotes the Fisher matrix S k,k ′ = ⟪O a k O a k ′ ⟫ c a∼P , and F k = ⟪O a k L ab P b P a ⟫ c a∼P , where the repeated indices b inside the brackets are summed over. The double brackets denote connected correlation functions ⟪AB⟫ c = ⟪AB⟫ − ⟪A⟫⟪B⟫ of expectation values with respect to the POVM-distribution P and O a k = ∂ θ k log P a .

Regularization to invert the (quantum) Fisher matrix
In many cases the (quantum) Fisher matrix turns out the be ill-conditioned with an eigenvalue spectrum that spans all numerical orders of magnitude. Therefore, ground state search with SR and time evolution via TDVP require suited regularization schemes when inverting the Fisher matrix. Three possible approaches are outlined below. While either the diagonal shift (Section 2.5.1) or the pseudo-inverse regularization (Section 2.5.2) should typically be used for SR, the effective regularization for real time evolution, e.g., targeted elimination of noisy contributions (Section 2.5.3), is under ongoing investigation and there are no generally established procedures to date [24,51].

Diagonal shift
An effective regularization for SR is to introduce a diagonal shift, i.e., to replace the Fisher matrix S k,k ′ byS Artificially amplifying the diagonal elements of the S-matrix mitigates the rank-deficiency and renders the matrix invertible. The shift parameter ν is typically reduced during the optimization according to a suited schedule [10]. The regularization via diagonal shift is not applicable for time evolution, because the TDVP relies on an accurate solution of the actual TDVP equation at each time point; by contrast, the ground state search is more forgiving in this regard because one is only interested in the final solution and typically not in the trajectory during the optimization. Moreover, other than the regularization schemes discussed below, the diagonal shift regularization can be combined with iterative solvers for the linear equation S k,k ′θ k ′ = F k , such as conjugate gradient [10].

Pseudo-inverse
The Moore-Penrose pseudo-inverse generalizes matrix inversion to rank-deficient matrices. The pseudo-inverse S + of a hermitian/symmetric matrix S with eigendecomposition S = V DV † , where D = diag(λ 1 , . . . , λ K ), can be constructed as For the case of ill-conditioned matrices with small but non-vanishing eigenvalues this can be generalized to an approximate pseudo-inverse with where we introduced the cutoff parameter pinv and assumed ordered eigenvalues λ 1 ≥ . . . ≥ λ K . For real time evolution it can be beneficial to choose a soft cutoff instead of Eq. (27), because eigenvalues that cross the cutoff can otherwise introduce spurious discontinuities in the dynamics. One possibility is The regularization with approximate pseudo-inverses requires diagonalization of the (quantum) Fisher matrix. Since the size of the Fisher matrix is determined by the number of variational parameters, this imposes immediate restrictions on the size of the variational ansatz.

Eliminating noisy contributions
Monte Carlo fluctuations that are blown up by the inversion of small eigenvalues of the Fisher matrix constitute a major source of instabilities when simulating real time evolution [24,51]. A mitigation strategy introduced in Ref. [24] takes the signal-to-noise ratio (SNR) of the MC estimates in the TDVP equation into account for the regularization. Denoting the force vector F k transformed to the eigenbasis of S as ρ k = V † k,k ′ F k ′ , its signal-to-noise ratio SNR(ρ k ) can be estimated using the available MC data. Then, a regularization that ignores exceedingly noisy components of the TDVP equation can be defined in analogy to the pseudo-inverse with The cutoff parameter defines a lower bound for the SNR that is tolerated and the number of discarded components can be tuned systematically by varying the number of MC samples, and thereby the SNR. This SNR-based regularization can be combined with the pseudo-inverse approach discussed above.

Neural quantum states
In the discussion of the previous sections we made no reference to any specific choice of the variational ansatz ψ θ (s) (or P a θ ). While these variational approaches are formulated for arbitrary wave function ansatzes, the particular choice will in general introduce a bias to the obtained results. Numerically exact simulations, however, require the possibility to systematically reduce this bias.
Artificial neural networks (ANNs) have been proven to be universal function approximators in the limit of large network sizes [52][53][54][55][56]. Hence, by choosing ANNs as the variational ansatz for the wave function the inductive bias of the NQS can be reduced systematically by increasing the network size. Thereby, the network size plays a similar role for NQS as the bond dimension in the established tensor network techniques [2][3][4].
In its most basic form, the ANN is a function that is composed from an alternating sequence of affine-linear and non-linear transformations. Each affine-linear transformation maps a vector of activations a (l−1) to a new vector of pre-activations from which the new activations are obtained by applying a non-linear function σ(⋅) element-wise, This iterative prescription is initialized by identifying a (0) with the input of the ANN. The set of weights W i constitutes the variational parameters, which we will in the following summarize in the variable θ.
For our purposes, we assume that the ANN encodes the logarithm of the wave function coefficients, i.e., where f θ (s) represents the ANN. In the following two subsections we discuss two further central design choices, namely holomorphic vs. non-holomorphic ANNs and the possibility to equip the NQS with an autoregressive structure to enable direct sampling. On this basis, arbitrary ANN architectures can be incorporated in the NQS, for example, Restricted Boltzmann Machines [10], convolutional [15], recurrent [57], sparse [58], or graph neural networks [59].

(Non-)holomorphic neural quantum states
The wave function coefficients ψ θ (s) are in general complex numbers, which can be obtained in different ways from an ANN. Three possible options are: 1. Single holomorphic network: Allow the weights and biases to be complex numbers, ϑ ∈ C, and choose holomorphic activation functions σ. Then f ϑ itself is a holomorphic function of the variational parameters ϑ.
2. Single non-holomorphic network: Construct an ANN with real parameters, θ ∈ R, with two real outputs, This yields a complex wave function coefficient through 3. Two real (non-holomorphic) networks: Use two separate ANNs f θ (f ) and g θ (g) for real and imaginary part In this case, independent variational parameters are used for the real and the imaginary part, θ = θ (f ) , θ (g) .

Autoregressive neural quantum states
Any joint probability distribution p(s) = p(s 1 , . . . , s N ) can be factorized and written as a product of conditional probabilities: If the random variables s i have few discrete outcomes, this property can be exploited to efficiently generate uncorrelated samples s ∼ p without resorting to Markov Chain Monte Carlo. Instead, the individual realizations of s i are sampled directly in a sequential manner.
Since the s i have only few outcomes, it is straightforward to draw a sample outcome for s 1 from p 1 . Subsequently, the same is true for s j with j > 1, because p j (s j s 1 , . . . , s j−1 ) is again a distribution with few outcomes for the given s 1 , . . . , s j−1 . ANNs, which incorporate this factorization into conditionals are called autoregressive models. There are different architectures with this property, such as Neural Autoregressive Density Estimators [11,60] or recurrent neural networks [57]. Utilizing direct sampling can be advantageous in cases where Markov Chain Monte Carlo (MCMC) becomes inefficient due to exceedingly long autocorrelation times [11]. In fact, direct sampling is also more efficient than MCMC when autocorrelation times are short. In MCMC consecutive samples are collected after performing one update sweep that typically consists of one Monte Carlo step per degree of freedom. Hence, generating one sample with MCMC comes at the cost of O(N ) network evaluations, where N is the system size. By contrast, generating one sample costs just one network evaluation when exploiting the autoregressive property.
However, the autoregressive property imposes specific constraints on the network architecture. Therefore, the advantage of more efficient sampling has to be weighed against inductive biases of the architecture.

Design choices
The jVMC codebase is devised as a transparent implementation of the core computational tasks required when dealing with NQS, which simultaneously provides efficiency and large flexibility. The primary purpose is to equip the user with these building blocks that enable the composition of a large variety of algorithms.
The code is designed to fully exploit the algorithms' typical amenability to hybrid parallelization, combining distributed single program multiple data (SPMD) computing with vectorization that can benefit substantially from the availability of accelerators like GPUs. Automatic differentiation is used to enable the composition of arbitrary ansatz functions for the variational quantum states. The code relies on the JAX library [36], which provides the functionality for automatic differentiation as well as vectorization and justin-time compilation; see Appendix A for basic examples illustrating JAX's functionality.
In this section we first explain how different levels of parallelism are exploited and how this is reflected in the API before describing the core modules of the package.

Parallelism
The computationally intense part of NQS-based algorithms is typically the evaluation of the ANN on large numbers of computational basis states. Since each individual ANN evaluation is independent of the others and largely consists of vectorizable operations, these algorithms are well suited to exploit the resources of distributed compute clusters with accelerators. The parallelization scheme that is supported by the jVMC codebase is depicted in Fig. 1. In the depiction the small blue boxes represent the task to evaluate the ANN on an individual computational basis configuration. The total set of input configurations is distributed across the available compute nodes, where, in turn, the assigned configurations are split up among the locally available accelerators. The accelerators work most efficiently when evaluating the ANN on large batches of input configurations simultaneously, see Section 3.1.2 below.
This parallelization scheme applies straightforwardly to the computation of observables following Eq. (4) when a large number of sample configurations s is given and it is applicable in a similar fashion to MC sampling. To achieve optimal performance in MCMC, it is beneficial to assign multiple independent Monte Carlo chains to each accelerator. The MCMC steps of all chains on one device can be performed in sync, again allowing for vectorized ANN evaluation on mutliple input configurations.
The code was designed with this hybrid parallelization scheme as a guiding principle. An important manifestation are the required array dimensions when interfacing jVMC: All data that is related to network evaluations will have two leading dimensions, namely the device dimension and the batch dimension. These are indicated by the stacks, which are assigned to individual nodes in Fig. 1. The corresponding layers of parallelism, multiple accelerators per node and vectorization, are explained in more detail in the following subsections.

Multiple accelerators per node
The jVMC codebase supports two possibilities to deal with compute clusters where multiple accelerators are attached to each node, namely, 1. Launch one MPI process per accelerator.
2. Distribute computation across the available devices, while working with a single process.
While the former is straightforward using the mpi4py package, the latter is enabled by automatic parallelization across devices using the pmap functionality of the JAX library. The jVMC.set pmap devices() function enables the user at the beginning of a program to choose for each MPI process which subset of the available devices to work with. For a homogeneous treatment of both options, all data arrays that are passed through the jVMC API have an additional leading device dimension to account for potential parallelization across devices. The size of this dimension corresponds to the number of devices used by the process and any computation will be distributed among the devices. It is important to realize and keep in mind that when working with multiple devices the device dimension is also physically distributed across the different devices. Hence, any computation on data with device dimension larger than one should be performed on the respective devices to avoid memory transfer overheads.
The default behavior of the JAX library is to allocate all available GPU memory at the beginning of a program. This can lead to conflicts if multiple processes can access the same GPU devices. This issue can be avoided by setting the XLA PYTHON CLIENT PREALLOCATE environment variable prior to running the program as >>> export XLA_PYTHON_CLIENT_PREALLOCATE=false

Batching for vectorization
The computationally intense operation during network evaluations is the matrix-vector product of the affine-linear transformation in Eq. (30), which can be turned into a matrixmatrix product when evaluating the ANN simultaneously on a batch of input configurations. The arithmetic intensity of multiplying a m × n-matrix with a n × k-matrix, i.e., the number of arithmetic operations per memory access, is This contains the arithmetic intensity of a matrix-vector product as the special case with k = 1, which is bounded from above irrespective of the values of m and n: By contrast, I(m, n, k) grows with increasing k (linearly for k ≪ min(m, n)), meaning for the typical affine-linear transformation of ANNs that a suited batching of computational tasks can substantially enhance the arithmetic intensity. Therefore, any operation implemented in jVMC is performed on a batch of input data. This means that following the leading device dimension, all data arrays have an additional batch dimension, which should always be chosen as large as possible given the available memory in order to keep the arithmetic units of the GPU busy.

Core modules
In the following we describe the main functionality of the core modules of the codebase. A detailed documentation is available online [40]. Fig. 2 provides an overview of the main modules. The core functionality is contained in jVMC.vqs, jVMC.sampler, and jVMC.operator. These implement the central operations of typical NQS algorithms in an efficient and flexible manner. The jVMC.util and the jVMC.mpi wrapper modules provide higher-level functionality useful for typical NQS algorithms and functionality for the evaluation of distributed Monte Carlo samples, respectively. The main features of these five modules are described in the following.
A recurring pattern is the possibility to create custom objects, e.g., network architectures or quantum operators. Their behavior has to be defined in the form of functions acting on a computational basis configuration. As a consistent design choice, these functions are always defined acting on a single input configuration. The vectorization for batched evaluations is then taken care of automatically within the respective classes of the jVMC codebase.

Variational wave functions
A core part of the jVMC codebase is the NQS class provided in the jVMC.vqs module -a wrapper class for variational wave functions, which provides an interface that other parts of the code rely on. The NQS class is used to wrap ANNs that are defined using the Flax module system Linen [61]. In this framework ANNs are defined in a very similar manner as in the popular Pytorch or Keras packages as classes that are derived from the flax.linen.Module base class. The minimal requirement for this new class is the implementation of a call member function, which evaluates the neural network on a single input configuration. For example, a Restricted Boltzmann Machine (RBM) with complex parameters, for which   Notice in this example code that the flax.linen package provides basic modules like, e.g., dense layers (flax.linen.Dense), that implement the affine transformation in Eq. (38). Moreover, the flax.linen.compact decorator enables the inlined definition of such submodules, which otherwise have to be defined in a separate setup member function (see Flax documentation for details [61]). In order to work with autoregressive NQS (see Section 2.6.2), a member function called sample has to be implemented as part of the corresponding Flax Linen module. This sample member function has to take the desired number of samples and a jax.random.PRNGKey as input arguments and it should return the resulting configurations.
The jVMC.nets module contains a number of pre-defined common network architectures, including Restricted Boltzmann Machines, Convolutional Neural Networks, and Recurrent Neural Networks. Notice, however, that these classes of architectures delineate general design principles, but leave a lot of freedom regarding the details of the implementation. The jVMC codebase is intended to encourage the exploration of new architectures, for which the provided examples should be regarded as possible starting points.
Once the network architecture is defined in the form of a flax.linen.Module, an instance of it can be initialized. Continuing the example from above, an RBM can be obtained as follows: The NQS class supports the three types of NQS introduced in Section 2.6.1, i.e., a single holomorphic or non-holomorphic ANN or two separate ANNs encoding logarithmic amplitude and phase, respectively. The corresponding ANNs in the form of Linen modules are passed to the NQS class at instantiation. In our example, we define a holomorphic RBM NQS with the following line:

vqs.NQS(net) # Initialize an NQS object
In order to use two separate networks for phase and amplitude, the two networks are passed as a tuple. In addition to this, the constructor takes as optional arguments a batchSize (the meaning of which is explained below) and a seed for the initialization of the network parameters.
The main purpose of the NQS class is to provide an interface for computing wave function coefficients as well as gradients of the variational ansatz. For a given sample of input configurations s the call method yields the logarithmic wave function coefficients log ψ θ (s) and the gradients method returns the logarithmic gradients O θ k (s) = ∂ θ k log ψ θ (s). These methods automatically vectorize the evaluation, meaning that the input requires the additional device and batch dimension as explained in Section 3.1 above. Since the batch-vectorization is limited by the available memory, the batch is prior to evaluation split into mini-batches of size batchSize, which are evaluated simultaneously. The batchSize used by the NQS class is fixed at instantiation (see above) and it should be chosen as large as possible with the given memory resources in order to render the network evaluation compute-bound, see Section 3.1.2.
Finally, the NQS class provides member functions get parameters, set parameters, and update parameters to manipulate the network parameters; see the documentation for details [39].

Operators
The jVMC.operator module implements the evaluation of the action of operators on computational basis configurations. In terms of an operatorÔ acting on the Hilbert space this corresponds to "on-the-fly" generation of matrix elements. Generally, this amounts to a mappingÔ Here, s denotes the input basis configuration and the s ′ j are the basis configurations for which the matrix elements are non-zero, O ss ′ j = ⟨s Ô s ′ j ⟩ ≠ 0. The abstract Operator class defines an interface for the implementation of mappings given in Eq. (39). Any specific operator can be implemented as a child of Operator, which has to implement a compile() member function. Such an implementation should return a JAX-jit-able function (see Appendix A.1) that returns for a given basis configuration s a tuple of two arrays, one containing the configurations s ′ j -the other the corresponding matrix elements O ss ′ j . The following example code shows the implementation of aσ x l operator acting on a chain of spin-1/2 degrees of freedom, i.e., s ∈ {0, 1} N : class SxOperator(jVMC.operator.Operator): """Define a '\hat\sigma_l^x' operator.""" def __init__(self, siteIdx): Any child of the abstract Operator class inherits the following member functions, which employ the operator action defined by the specific compile function: • get s primes(s): Returns s ′ j and O ss ′ j for a batch of input states s. In accordance with Section 3.1 the input dimension is D × B × (spatial dimensions), where D is the device dimension and B is the batch size. The output is a tuple of two arrays of shape D × M × (spatial dimensions) and D × M , where M is the total number of non-zero matrix elements across the whole batch.
• get O loc(logPsiS, logPsiSP): Assuming that get s primes(s) has been called before on a batch of input states s, this function computes the corresponding O θ loc (s) for each element of the batch. As input it requires the logarithmic wave function coefficients log ψ θ (s) (argument logPsiS) and log ψ θ (s ′ ) (argument logPsiSP). It returns an array of size D × B with the complex-valued O θ loc (s).
Thereby, children of the Operator class provide comprehensive functionality to evaluate O θ loc (s) for arbitrary operatorsÔ in an automatically vectorized manner. Consider, for example, that H is an operator object implementing the action of a HamiltonianĤ, psi is the variational wave function, i.e., an instance of the NQS class, and sampleConfigs is a batch of basis configurations. Then, evaluating E θ loc (s) is achieved by the following lines of code: sampleOffdConfigs, matEls = H.get_s_primes(sampleConfigs) sampleLogPsi = psi(sampleConfigs) sampleLogPsiOffd = psi(sampleOffdConfigs) Eloc = H.get_O_loc(sampleLogPsi, sampleLogPsiOffd) Besides the abstract Operator class that serves as a basis to implement arbitrary operators, the jVMC codebase provides two derived classes for typical applications: • BranchFreeOperator: This class provides functionality to compose many-body operators as tensor products of branch-free operators acting on local Hilbert spaces. The term "branch-free" means that the local operators contain only a single nonzero entry in each row/column. An example are the Pauli operators, which are pre-defined as part of the module.
• POVMOperator: The POVMOperator class comprises utility to construct time-evolution operators in the POVM-formalism. Note that the computation of observables differs in this formalism, such that the returned operator may only be interpreted as the object generating the real time-evolution.
The documentation [39] explains the corresponding API in detail.

Sampler
The purpose of the sampler module is to handle all sampling-related tasks in a unified manner. The module provides two classes, the ExactSampler and the MCSampler. The ExactSampler evaluates the network on all the (exponentially many) basis configurations. Obviously, there is no associated computational speedup with this method and its main purpose is to make troubleshooting easier when testing new functionalities. In contrast, the MCSampler may be used to generate MC samples from the Born distribution of ψ θ (s) either by direct sampling from an autoregressive NQS or by Metropolis-Hastings MCMC. The common interface of both sampler classes is a sample member function, which generates a sample in the respective manner and returns a tuple of the generated basis configurations, the corresponding coefficients log ψ θ (s), and the probabilities ψ θ (s) 2 in the case of the ExactSampler or None in the case of the MCSampler. Hence, after instantiating a sampler object called mySampler, generating a set of samples amounts to the line sampleConfigs, sampleLogPsi, sampleProb = mySampler.sample() At instantiation both sampler types need to be passed the variational wave function and the sample shape. The ExactSampler furthermore requires the local Hilbert space dimension lDim.
For the initialization of a MCSampler a jVMC.random.PRNGKey is an additional necessary argument. An updateProposer function needs to be provided for MCMC sampling. The expected signature is updateProposer(key, config, **kwargs) and the function is supposed to return an updated basis configuration that is used as proposed move in the MCMC algorithm. If the optional argument updateProposerArg is given when instantiating a MCSampler, its value is passed to the updateProposer as kwargs. Further initialization arguments are the number of samples to generate numSamples, and for MCMC sampling the number of update proposals per sweep sweepSteps and the number of sweeps used for initial thermalization ("burn-in") thermalizationSweeps. The argument numChains defines the number of MCMC chains that are run in a vectorized manner in order to enhance the computational efficiency.
If the NQS which was given during the initialization of the MCSampler is an autoregressive model, i.e., if a sample member function exists in the respective Linen module, direct sampling is automatically used instead of MCMC sampling.
In both the ExactSampler and the MCSampler the generation of samples is automatically distributed across MPI processes and locally available devices. This means for MC sampling that when running N P MPI processes and calling the sample member function to obtain N MC samples, each of the processes will produce N MC N P samples. Due to the internal vectorization the total number of samples produced by the MCSampler can slightly exceed the number of samples asked for in order to match array dimensions.

MPI wrapper
It's amenability to parallelization across distributed compute clusters is a key feature of VMC algorithms, see Section 3.1. After generating samples locally on the different nodes, the results have to be collected in order to compute the quantities of interest. The mpi wrapper module provides a system for the distributed sampling and the subsequent reduction tasks. For this purpose, it wraps the required MPI communications, for which the mpi4py package [37,38] is used.
The function jVMC.mpi wrapper.distribute sampling distributes sampling tasks across the available processes and devices. For a desired total number of samples this function determines how many samples should be generated by each sampling process. The ExactSampler or MCSampler class invoke this function to automatically distribute the sampling tasks.
Assume that mySampler is an instance of a sampler class, psi is a variational quantum state, and op is an instance of a class derived from the Operator class, associated with a quantum operatorÔ. Further options are get global variance, get global covariance, and get global sum to compute the variance, co-variance matrix, and the sum of all elements, respectively.

Utilities
The jVMC.util module comprises higher-level code which is not at the heart of the jVMC codebase. The code in this module • either contains functionality to combine the workings of the codebase on an intermediate (e.g., jVMC.util.tdvp, jVMC.util.stepper) or higher (e.g., ground state search or network initialization in jVMC.util.util) level, • or solves tasks that are not immediately related to the core functionality (e.g. IOtasks in jVMC.util.output manager or symmetry transformations in jVMC.util.symmetries).
TDVP The jVMC.util.tdvp module is centered around solving the equation where γ ∈ {1, i} is either the real or the imaginary unit. The module comprises a TDVP class with a call member function that returns a (regularized) solutionθ of Eq. (40), which corresponds to an SR step, Eq. (10), for γ = 1 or a real time step following the TDVP equation (19) for γ = i. The signature of the call method matches the standard right hand side required by ordinary differential equation solvers of the scipy library [62] or those provided as part of the jVMC codebase in the jVMC.util.stepper module (see below). Hence, an instance of the TDVP class can be passed as right hand side to these integrators. Invoking the call member function of a TDVP object, essentially amounts to carrying out the steps of Algorithm 1. For solving Eq. (40) the TDVP class offers a number of options, of which we highlight a few, while referring to the detailed documentation [40] for a complete list: • Choose γ by setting the rhsPrefactor at initialization.
• Set the regularization parameters ν, pinv , and SNR described in Section 2.5. The respective regularization scheme is not effective if the corresponding parameter is set to zero.
• The hamiltonian argument may be an Operator object or a function that takes a real number as argument and returns an Operator object. The latter can be used to solve the TDVP equation with time-dependent Hamiltonians.
• We observed that the CUDA implementation of the diagonalization can be unstable if the (quantum) Fisher matrix S is ill-conditioned. To circumvent this issue, we introduced the diagonalizeOnDevice initializer argument, which allows to choose where to perform the diagonalization of the matrix.
• If the call function receives the optional intStep keyword argument, and if intStep==0, various quantities like energy mean and variance, S and F , as well as residuals and the TDVP error of the solution are stored. These quantities can subsequently be retrieved via the corresponding get * member functions of the TDVP class.
Stepper The stepper module comprises two classes for the integration of ordinary differential equations: A simple Euler integrator and an AdaptiveHeun integrator. Both classes contain a step member function as common interface. Assuming that the ODE is given in the standard formẏ(t) = f (y(t), t, p), the step function takes t n , f , y(t n ), and optional additional parameters p as argument and returns y n+1 ≡ y(t n+1 ) following the respective integration scheme. The AdaptiveHeun stepper is a second-order consistent adaptive integrator that adjusts the integration step size τ to achieve a given accuracy step . For this purpose, it computes solutions y (τ ) n+1 and y (τ 2) n+1 using a Runge-Kutta scheme with two different step sizes τ and τ 2. Based on the difference of the solutions δ = y the time step is adjusted according to τ ← τ ( step δ) 1 3 in order to guarantee the desired accuracy.
One important point to observe is that the choice of the normalization measure, which is in principle arbitrary, can have a severe impact on the performance (i.e., the resulting step size). The norm can be chosen by passing a function that computes it as normFunction argument to the step function. A natural choice is to be sensitive only to deviations which correspond to actual changes in the wave-function. This is given, when using the (quantum) Fisher norm Output Manager The output manager handles all IO-related tasks in a unified fashion. For output to a file it relies on the hdf5 data-format which is handled using the h5py library [63]. During evolution, three different types of data are typically stored: • physical observables, corresponding to the write observables function expecting a timestamp and a dictionary of to-be-written observables, • metadata, corresponding to the write metadata function expecting a timestamp and a dictionary of metadata to be written, • network checkpoints, corresponding to the write network checkpoints function expecting a timestamp and the set of network parameters at that time step. This allows to restart the simulation at arbitrary times or to sample observables which have not been stored during the initial run, using the get network checkpoint function to set desired network parameters for all MPI-processes.
Moreover, the OutputManager class provides a print function for printing to the standard output, which blocks all MPI processes except for the root process to avoid repeated output.
Symmetries Incorporating symmetries in the NQS turned out essential for high accuracy in various applications. For the symmetrization a typical task is to produce the symmetry-transformed input configurations s ω(l) for all transformations ω ∈ Ω; see, e.g., Section 5.1 or Refs. [10,17]. The jVMC.util.symmetries module contains functions that return the complete sets of symmetry transformations Ω for simple lattices in the form of permutation matrices. These can be used in definitions of NQS architectures to produce all symmetry-related configurations by simple dot-products, see the code of symmetrized architectures in the jVMC.nets module for examples.

Performance
One of the key features of the jVMC codebase is the efficient exploitation of large scale supercomputing resources for NQS simulations. Ideally, many accelerators (GPUs or TPUs) are available to share the independent workload for each batch of samples as described in Section 3.1. Fig. 3a demonstrates close to ideal speedup achieved by the MPI parallelization of the code. As a benchmark we consider an SR step for a 50-site transverse-field Ising chain with periodic boundary conditions. The NQS architecture is a symmetrized RNN with 1510 variational parameters and the expectation values are estimated using N MC = 3×10 5 Monte Carlo samples. The displayed timings were taken on the JUWELS Booster module at the Jülich Supercomputing Centre [64], where each node is equipped with four NVIDIA A100 GPUs. Using the single-node performance as baseline, we observe almost ideal speedup for up to 16 nodes or 64 GPUs. The breakdown of compute times on a single node in Fig. 3b shows that network evaluations account for the majority of the computational cost, which explains the efficiency of the parallelization. Notice the negligible contribution of the sample generation, which is due to direct sampling from the autoregressive NQS.
Additionally, we benchmark the performance gain obtained when increasing the batch size and, thereby, the arithmetic intensity of the network evaluation tasks. As discussed in Section 3.1.2, the batch size should ideally be chosen to saturate the GPU's memory capacity. To illustrate the significance of suited batching, we again time one SR update step, for this purpose considering a 30-site spin chain and an RNN without symmetries. For N MC = 10 4 , the speedup gained by increasing the batch size is shown in Fig. 3(c)  acceleration of the forward pass by more than a factor of 100 compared to the baseline with batch size 16 underlines the importance of batching for optimal performance. The performance gain by batching is also the reason why running multiple parallel Monte Carlo chains can enhance the performance. However, there exists a sweet spot in the number of chains as shown in Fig. 3(d). For very few MC-chains the GPU is not used to its full capacity as the network is evaluated on too small batches. In the opposite limit, we attribute the deteriorating performance to a saturation of the occupation of the Streaming Multiprocessors such that computations start to be performed sequentially. This leads to effective idle times of the individual chains and therefore to degraded overall performance. Fig. 3(d) shows optimum performance with of the order of 100 parallel Monte Carlo chains, which we found to be a good number in various simulations; here, we used a one-dimensional complex Convolutional Neural Network with an input size of N = 80 and four subsequent layers of sizes 10, 8, 6, 4 with a filter diameter of 20 (6536 variational parameters). The sample size was N MC = 10 4 and the MCMC was performed with sweep size N and 20 initial thermalization sweeps. These simulation parameters are again used for the following test case.
Finally, we include a direct comparison to the NetKet library [65]; notice that NetKet version 3.2, which we are using for this comparison, relies mostly on JAX for the computationally intense parts of the code -this recent development is not described in Ref. [65]. As a benchmark problem we choose the evaluation of an operator expectation value, i.e., the evaluation of Eq. (5), whose cost is comparable to the computation of the energy gradient ∂ θ k E(θ). As operator we use the transverse-field Ising Hamiltonian with system size N = 80. The NQS architecture is the same as in Fig. 3(d) and the sample of size N MC = 10 4 is generated using the approximately optimal number of 350 parallel MCMC chains. We separate the timing into three contributions, (i) the sampling, (ii) the computation of the matrix elements, and (iii) the final evaluation of Eq. (5), which includes the evaluation of ψ(⃗ s) on the involved configurations s and s ′ . The resulting runtimes are shown in Fig. 4(a) for jVMC and in Fig. 4(b) for NetKet. For the given task jVMC is in total about 10% faster than NetKet. The difference between both codes lies in the computation of matrix elements and the final evaluation. To our understanding the difference in runtimes originates in the fact that in jVMC all computational operations are performed on the GPU, while NetKet performs the computation of matrix elements on the CPU. Hence, the parallelism of this task is not fully exploited in NetKet and additional overheads are created due to data transfer between CPU and GPU.

Ground state search
For the exemplary ground state search we consider the one-dimensional transverse-field Ising modelĤ = − lẐ lẐl+1 − g lX l (41) with Pauli operatorsX l andẐ l acting on lattice site l and periodic boundary conditions. Fermionization of the Hamiltonian results in a single particle problem, which can be solved efficiently for large finite systems [66]. In the following, this is our reference for the exact ground state energy.
The un-normalized wave function ansatz is a fully connected single-layer CNN with real parameters, meaning that the NQS automatically incorporates the translational symmetry of the Hamiltonian: Here, T is the set of translations and f ELU (x) denotes the ELU activation function The ground state search is performed by Stochastic Reconfiguration, see Section 2.2, using the learning rate τ = 10 −2 , N MC = 4 × 10 4 MC samples, an initial regularization with diagonal shift ν 0 = 10 that decays with increasing step number n as ν = 0.95 n ν 0 , and a pseudo-inverse cutoff parameter pinv = 10 −8 (see Section 2.5).  This basic ground state search is implemented in the example notebook examples/ex0 ground state search.ipynb, which is contained as an example in the codebase repository and reproducing them is feasible on a single GPU.

Unitary dynamics of pure states
As an example of unitary dynamics we simulate a quench in the transverse-field Ising model on a square latticeĤ whereX j andẐ j are Pauli operators and ⟨i, j⟩ denotes pairs of nearest neighbors on the lattice. The initial state ψ 0 ⟩ is the paramagnetic ground state at J = 0, which we find by an initial ground state search. Subsequently, the system evolves with the Hamiltonian parameters J = 1 g = 1.5g c , where g c = 3.04438J [67] is the critical transverse field strength at which the model exhibits a phase transition. We consider open boundary conditions, meaning that the physical setting is a variant of the situation addressed in [24].
The network architecture of the NQS is a single-layer fully connected complex CNN with suited zero-padding of the computational basis configuration to account for the open boundary conditions. The activation function is the Taylor series of log cosh(x) truncated at sixth order. For a system of 8 × 8 lattice sites we show in Fig. 6a) the time evolution of the transverse magnetization at a central site, ⟨X L 2,L 2 ⟩, obtained with different network sizes to demonstrate convergence, namely 8, 16, 24, and 32 channels in the CNN. Fig. 6b) displays the spatio-temporal spreading of correlations based on the correlation function ⟨Ẑ L 2,L 2Ẑi,j ⟩.
This result was obtained using N MC = 10 6 Monte Carlo samples. The TDVP equation was solved using cutoff parameters pinv = 10 −8 and SNR = 2. The tolerance for the adaptive time step of the integrator was step = 10 −6 .

Dissipative dynamics
Different paths to variational approximations ofρ for open system dynamics using neural networks have been explored [27][28][29], which were briefly introduced in Section 2.4. Here, we showcase results that were obtained in Ref. [29] employing the probability based POVM approach in conjunction with the jVMC codebase. We study a quench of a z-polarized state in a one-dimensional quantum Ising model with transverse and longitudinal field; the openness consists of dephasing with relative strength γ J = 0.25, which eventually drives the system to the featureless steady stateρ SS ∝ 1. The system's Hamiltonian is accordingly given byĤ = and is of interest as it exhibits confinement effects which may be regulated by tuning h z [68]. Non-zero values of h z break the Z 2 symmetry of the TFIM-Hamiltonian and function as an energy penalty for domains that point in opposite direction from h z , thereby suppressing the spreading of correlations. The obtained results are shown in Fig. 7.
Here, an RNN with 5 layers was used where in each layer, latent information is carried in the form of a hidden state of length 12. The corresponding transformations are described by 1456 variational parameters, which are updated using N MC = 1.6 × 10 5 samples in each time step with a prescribed integration tolerance step = 10 −3 and pinv = 10 −8 . The SNR regularization scheme was not used and translational invariance was enforced as described in Sec. 3.2.5.

Conclusion
The jVMC codebase provides a basic and flexible framework for the composition of NQS algorithms. The typical building blocks -on-the-fly computation of operator matrix elements, sampling from the Born distribution, and evaluating the wave function ansatz and its gradients -are implemented in an efficient manner. While designed to fully exploit the resources of distributed compute clusters, the underlying just-in-time compilation allows the user to run the identical Python code just as well on a desktop computer without accelerators.
Compared to the existing NetKet library [65], the jVMC codebase was devised following a different philosophy. While NetKet provides implementations of many different NQS algorithms, network architectures, and physical models on top of the core building blocks, the jVMC codebase focuses on exposing a minimal system of efficiently implemented modules that facilitates the custom composition of such algorithms. This bottom-up ansatz is intended to free future method development from the necessity of first having to create parallelized and performant implementations of the core tasks; at the same time we aimed to impose only minimal bias with respect to the details of resulting algorithms.
The presented codebase has already been used in research applications to compute the energy gap of a two-dimensional quantum Ising model and its unitary dynamics across a quantum phase transition [25], and also for the simulation of open quantum spin systems in one and two dimensions [29]. We expect that it will serve as a useful basis for further method development and applications in these and other directions, where VMC with NQS can potentially push our classical simulation capabilities.
behavior of JAX-transformed functions. Some workarounds are needed to reconcile JAX with object-oriented code.
• Control flow needs some care. Replace Python for-loops or if-branching with primitives from the jax.lax sub-module, see the JAX documentation for details.
• JAX arrays are immutable in order to enable function tracing. However, when attempting to change the content of a JAX array, the thrown error gives hints to the correct pure analog of this operation. For a comprehensive list, we refer the reader to the section about "the sharp bits" of the JAX documentation [70]. The example codes below are variations of examples shown in the JAX documentation.

A.1 Just-in-time compilation
jax.jit returns a compiled version of the function that is given as the argument. Importantly, the function will be recompiled whenever the input data type or shape is changed. When running this piece of code, the function will first be compiled when the last line is executed and then the compiled function will be called. In this example, calling the jitted function instead of the original Python function will not exhibit any significant differences (besides a compilation overhead when calling the jitted function for the first time). Compiling functions that execute considerable computational workload, instead, can yield significant gains in performance. jax.jit can be used to showcase an example of unexpected behavior when global variables are involved in JAX-transformed functions: >>> global g >>> g = 3 >>> def f(x): ... return x + g >>> f_jit = jax.jit(f) >>> f_jit(0) DeviceArray(3, dtype=int32) >>> g = 4 >>> f_jit(0) DeviceArray(3, dtype=int32) >>> f_jit(0.) DeviceArray(4., dtype=float32) As mentioned above, a jitted function is recompiled if invoked with new data types, which is the reason for the output of the last line.

A.2 Automatic differentiation
The restriction to pure functions is the basis for an easily accessible implementation of automatic differentiation in JAX. The derivative of arbitrary pure functions can be computed with jax.grad. Moreover, the function returned by jax.grad can be used in further transformations, for example to compute the second derivative in the code below. If derivatives for more than the first argument are required, the keyword-argument argnums needs to be specified.

A.3 Vectorization
To accelerate typically slow Python for-loops and in order to enable the full exploitation of accelerators, JAX comprises can transform a single-instruction single-data (SISD) operation (like the function f defined above) into a function that works on multiple data (single-instruction multiple-data, SIMD). For a given function that operates on some input data, jax.vmap returns a vectorized version of that function, that applies the original function element-wise to vectors of input data. As mentioned above, all JAX function transformation are composable, meaning that one can arbitrarily combine jax.jit, jax.grad, and jax.vmap as demonstrated in the following example:

A.4 Parallelization across multiple accelerators
When multiple accelerators are available, jax.pmap can map a function to these devices, such that the function is executed in parallel with the elements along one axis of the input data as arguments. Thereby, jax.pmap is superficially very similar to jax.vmap. But, in fact, the effect of both transformations is quite different: While jax.vmap vectorizes a function, jax.pmap creates copies of the function which are then carried out on each device individually. This means, for example, that the data that is passed to a pmapped function has to be physically scattered across different devices.
There are some pitfalls that one needs to be aware of; one example is that the leading axis dimension of inputs must not be greater than the number of devices. A practical way of circumventing such issues and to obtain good parallelization results is to reshape the leading axis of datasets into the number of devices, which can be obtained using jax.local device count(). In the example below we assume that two GPUs are available.