Some progress with jax-cpu-mpi
Andrey Latyshev committed Jan 19, 2025
1 parent 1256421 commit e71a022
Showing 5 changed files with 1,029 additions and 23 deletions.
244 changes: 244 additions & 0 deletions jax-gpu/constitutive_model.py
@@ -0,0 +1,244 @@
from mpi4py import MPI
from petsc4py import PETSc

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np

E = 6778  # [MPa] Young's modulus
nu = 0.25  # [-] Poisson's ratio
c = 3.45  # [MPa] cohesion
phi = 30 * np.pi / 180  # [rad] friction angle
psi = 30 * np.pi / 180  # [rad] dilatancy angle
theta_T = 26 * np.pi / 180  # [rad] transition angle as defined by Abbo and Sloan
a = 0.26 * c / np.tan(phi)  # [MPa] tension cut-off parameter
stress_dim = 4

def J3(s):
    """Third invariant of the deviatoric stress (4-component vector notation)."""
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)


def J2(s):
    """Second invariant of the deviatoric stress."""
    return 0.5 * jnp.vdot(s, s)


def theta(s):
    """Lode angle of the deviatoric stress `s`."""
    J2_ = J2(s)
    arg = -(3.0 * np.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2_ * J2_ * J2_))
    arg = jnp.clip(arg, -1.0, 1.0)
    theta = 1.0 / 3.0 * jnp.arcsin(arg)
    return theta


def sign(x):
    """Returns -1 for negative x and +1 otherwise (including zero)."""
    return jax.lax.cond(x < 0.0, lambda x: -1, lambda x: 1, x)


def coeff1(theta, angle):
    return np.cos(theta_T) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.sin(theta_T)


def coeff2(theta, angle):
    return sign(theta) * np.sin(theta_T) + (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.cos(theta_T)


coeff3 = 18.0 * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T)


def C(theta, angle):
    return (
        -np.cos(3.0 * theta_T) * coeff1(theta, angle) - 3.0 * sign(theta) * np.sin(3.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def B(theta, angle):
    return (
        sign(theta) * np.sin(6.0 * theta_T) * coeff1(theta, angle) - 6.0 * np.cos(6.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def A(theta, angle):
    return (
        -(1.0 / np.sqrt(3.0)) * np.sin(angle) * sign(theta) * np.sin(theta_T)
        - B(theta, angle) * sign(theta) * np.sin(3 * theta_T)
        - C(theta, angle) * np.sin(3.0 * theta_T) * np.sin(3.0 * theta_T)
        + np.cos(theta_T)
    )


def K(theta, angle):
    """Deviatoric shape function, smoothed near the corners following Abbo and Sloan."""

    def K_false(theta):
        return jnp.cos(theta) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * jnp.sin(theta)

    def K_true(theta):
        return (
            A(theta, angle)
            + B(theta, angle) * jnp.sin(3.0 * theta)
            + C(theta, angle) * jnp.sin(3.0 * theta) * jnp.sin(3.0 * theta)
        )

    return jax.lax.cond(jnp.abs(theta) > theta_T, K_true, K_false, theta)


def a_g(angle):
    """Tension cut-off parameter `a` rescaled for the angle in use (phi or psi)."""
    return a * np.tan(phi) / np.tan(angle)

dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=PETSc.ScalarType,
)
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)
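For orientation, `dev` and `tr` act as the deviatoric projector and the trace operator in this 4-component notation; a quick sanity-check sketch (not part of the commit, using a hypothetical stress vector `sigma_demo`) shows that they decompose a stress state exactly:

```python
# Sketch (not part of the commit): dev/tr split a stress vector into deviatoric
# and volumetric parts, which reassemble to the original vector.
sigma_demo = np.array([-12.0, -5.0, -5.0, 1.5], dtype=PETSc.ScalarType)  # hypothetical stress state
s_demo = dev @ sigma_demo   # deviatoric part
I1_demo = tr @ sigma_demo   # first invariant (trace)
assert np.allclose(s_demo + (I1_demo / 3.0) * tr, sigma_demo)
```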


def surface(sigma_local, angle):
    """Mohr-Coulomb-type surface smoothed following Abbo and Sloan."""
    s = dev @ sigma_local
    I1 = tr @ sigma_local
    theta_ = theta(s)
    return (
        (I1 / 3.0 * np.sin(angle))
        + jnp.sqrt(
            J2(s) * K(theta_, angle) * K(theta_, angle) + a_g(angle) * a_g(angle) * np.sin(angle) * np.sin(angle)
        )
        - c * np.cos(angle)
    )


def f(sigma_local):
    """Yield function (uses the friction angle phi)."""
    return surface(sigma_local, phi)


def g(sigma_local):
    """Plastic potential (uses the dilatancy angle psi)."""
    return surface(sigma_local, psi)


dgdsigma = jax.jacfwd(g)

lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
C_elas = np.array(
    [
        [lmbda + 2 * mu, lmbda, lmbda, 0],
        [lmbda, lmbda + 2 * mu, lmbda, 0],
        [lmbda, lmbda, lmbda + 2 * mu, 0],
        [0, 0, 0, 2 * mu],
    ],
    dtype=PETSc.ScalarType,
)
S_elas = np.linalg.inv(C_elas)
ZERO_VECTOR = np.zeros(stress_dim, dtype=PETSc.ScalarType)

def deps_p(sigma_local, dlambda, deps_local, sigma_n_local):
    """Plastic strain increment: zero in the elastic regime, dlambda * dg/dsigma otherwise."""
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def deps_p_elastic(sigma_local, dlambda):
        return ZERO_VECTOR

    def deps_p_plastic(sigma_local, dlambda):
        return dlambda * dgdsigma(sigma_local)

    return jax.lax.cond(yielding <= 0.0, deps_p_elastic, deps_p_plastic, sigma_local, dlambda)


def r_g(sigma_local, dlambda, deps_local, sigma_n_local):
    """Residual of the stress update."""
    deps_p_local = deps_p(sigma_local, dlambda, deps_local, sigma_n_local)
    return sigma_local - sigma_n_local - C_elas @ (deps_local - deps_p_local)


def r_f(sigma_local, dlambda, deps_local, sigma_n_local):
    """Residual of the consistency condition: equals dlambda if elastic, f(sigma) if plastic."""
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def r_f_elastic(sigma_local, dlambda):
        return dlambda

    def r_f_plastic(sigma_local, dlambda):
        return f(sigma_local)

    return jax.lax.cond(yielding <= 0.0, r_f_elastic, r_f_plastic, sigma_local, dlambda)


def r(y_local, deps_local, sigma_n_local):
    """Full residual of the local system with unknowns y = (sigma, dlambda)."""
    sigma_local = y_local[:stress_dim]
    dlambda_local = y_local[-1]

    res_g = r_g(sigma_local, dlambda_local, deps_local, sigma_n_local)
    res_f = r_f(sigma_local, dlambda_local, deps_local, sigma_n_local)

    res = jnp.c_["0,1,-1", res_g, res_f]  # concatenates an array and a scalar
    return res

drdy = jax.jacfwd(r)

Nitermax, tol = 200, 1e-10

ZERO_SCALAR = np.array([0.0])


def return_mapping(deps_local, sigma_n_local):
    """Performs the return-mapping procedure.

    It solves the elastoplastic constitutive equations numerically by applying
    Newton's method at a single Gauss point. The Newton loop is implemented via
    `jax.lax.while_loop`.

    The function returns `sigma_local` twice so that its value can be reused
    after differentiation: once we apply `jax.jacfwd(return_mapping, has_aux=True)`,
    the resulting function has an output of
    `(C_tang_local, (sigma_local, niter_total, yielding, norm_res, dlambda))`.

    Returns:
        sigma_local: The stress at the current Gauss point.
        niter_total: The total number of iterations.
        yielding: The value of the yield function.
        norm_res: The norm of the residuals.
        dlambda: The value of the plastic multiplier.
    """
    niter = 0

    dlambda = ZERO_SCALAR
    sigma_local = sigma_n_local
    y_local = jnp.concatenate([sigma_local, dlambda])

    res = r(y_local, deps_local, sigma_n_local)
    norm_res0 = jnp.linalg.norm(res)

    def cond_fun(state):
        norm_res, niter, _ = state
        return jnp.logical_and(norm_res / norm_res0 > tol, niter < Nitermax)

    def body_fun(state):
        norm_res, niter, history = state

        y_local, deps_local, sigma_n_local, res = history

        # One Newton step: solve the linearised system and update the unknowns.
        j = drdy(y_local, deps_local, sigma_n_local)
        j_inv_vp = jnp.linalg.solve(j, -res)
        y_local = y_local + j_inv_vp

        res = r(y_local, deps_local, sigma_n_local)
        norm_res = jnp.linalg.norm(res)
        history = y_local, deps_local, sigma_n_local, res

        niter += 1

        return (norm_res, niter, history)

    history = (y_local, deps_local, sigma_n_local, res)

    norm_res, niter_total, history = jax.lax.while_loop(cond_fun, body_fun, (norm_res0, niter, history))

    y_local = history[0]
    sigma_local = y_local[:stress_dim]
    dlambda = y_local[-1]
    sigma_elas_local = C_elas @ deps_local
    yielding = f(sigma_n_local + sigma_elas_local)

    return sigma_local, (sigma_local, niter_total, yielding, norm_res, dlambda)

def constitutive_response(sigma_local, sigma_n_local):
    deps_elas = S_elas @ sigma_local
    sigma_corrected, state = return_mapping(deps_elas, sigma_n_local)
    yielding = state[2]
    return sigma_corrected, yielding
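The `return_mapping` docstring above points to pairing it with `jax.jacfwd(..., has_aux=True)` to obtain the consistent tangent alongside the stress update. As a rough, hedged sketch of that usage (not part of the commit; `tangent_and_stress`, `deps_batch`, and `sigma_n_batch` are hypothetical names and placeholder data), the tangent could be evaluated over a batch of Gauss points like this:

```python
# Sketch (assumption, not part of the commit): forward-mode AD through the Newton
# loop gives the consistent tangent; vmap maps it over a batch of Gauss points.
dsigma_ddeps = jax.jacfwd(return_mapping, has_aux=True)

def tangent_and_stress(deps_local, sigma_n_local):
    C_tang_local, state = dsigma_ddeps(deps_local, sigma_n_local)
    sigma_local, niter_total, yielding, norm_res, dlambda = state
    return C_tang_local, sigma_local

# Hypothetical placeholder batches: one row per Gauss point, small elastic strain increments.
deps_batch = np.full((16, stress_dim), 1e-5)
sigma_n_batch = np.zeros((16, stress_dim))
C_tang_batch, sigma_batch = jax.jit(jax.vmap(tangent_and_stress, in_axes=(0, 0)))(deps_batch, sigma_n_batch)
```

Forward-mode (`jacfwd`) fits here because `jax.lax.while_loop` does not support reverse-mode differentiation.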
46 changes: 26 additions & 20 deletions jax-gpu/jax-gpu.py
@@ -14,29 +14,22 @@
from jax.sharding import PartitionSpec as P
from jax._src import distributed

from dolfinx import mesh, fem
import basix

jax.distributed.initialize()
print(f"Backend: {jax.default_backend()}")
print(f"Global devices: {jax.devices()}\n")
print(f"Local devices: {jax.local_devices()}\n")
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
local_device_ids = [int(i) for i in cuda_visible_devices.split(",")]
print(f"CUDA_VISIBLE_DEVICES = {local_device_ids} ")
jax.distributed.initialize(local_device_ids=local_device_ids)
local_device_ids = [int(i) for i in cuda_visible_devices.split(",")]
print(distributed.global_state.client)
print(distributed.global_state.service)
print(distributed.global_state.process_id)
# import os
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
# jax.distributed.initialize()
print(f"Backend: {jax.default_backend()}")
print(f"Global devices: {jax.devices()}\n")
print(f"Local devices: {jax.local_devices()}\n")

# A = jax.random.uniform(jax.random.key(0), (3,3,3), dtype=jnp.float64)
# A_sym = 0.5 * (A + A.T)
# result = jax.lax.linalg.eigh(A, lower=True, symmetrize_input=True,
# sort_eigenvalues=True, subset_by_index=None)

# LU_vec = jax.jit(jax.vmap(jax.lax.linalg.lu, in_axes=(0)))
# result = LU_vec(A_sym)
# print(result)


E = 6778 # [MPa] Young's modulus
nu = 0.25 # [-] Poisson's ratio
@@ -51,21 +44,34 @@
],
dtype=np.float64,
)
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)

def f(eps):
return C_elas @ eps
"""Function for benchmarking"""
return tr @ C_elas @ eps

f_vec = jax.vmap(f, in_axes=(0))
f_vec_jit = jax.jit(f_vec)

N = 10
domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.triangle)
Q_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=1, value_shape=())
Q = fem.functionspace(domain, Q_element)
scale_var = fem.Function(Q)

if MPI.COMM_WORLD.rank == 0:
print(f"rank = {MPI.COMM_WORLD.rank} Globally: #DoFs(Q): {Q.dofmap.index_map.size_global:6d}\n", flush=True)

print(f"rank = {MPI.COMM_WORLD.rank} Locally: #DoFs(V_alpha): {Q.dofmap.index_map.size_local:6d} scale_var {scale_var.x.array.shape}", flush=True)

# def f(x): # function we're benchmarking (works in both NumPy & JAX)
# return x.T @ (x - x.mean(axis=0))

f_vec_jit = jax.jit(f_vec)

N_list = 12*np.array([10, 100, 1000, 10000, 100000, 1000000])

def benchmarking(sharding=None):
def measurement(N, sharding=sharding):
eps_np = np.ones((N, 4), dtype=np.float64) # same as JAX default dtype
# x_np = np.ones((N, N), dtype=np.float64) # same as JAX default dtype
eps_np = np.ones((N, 4), dtype=np.float64)

time_numpy = timeit(lambda: f_vec(eps_np), number=1)

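Since the benchmark takes a `sharding` argument and the file already imports `PartitionSpec as P`, one plausible way to build a sharding for it (a sketch under assumptions, not part of the commit) is to lay the batch axis of `eps` out over all visible devices:

```python
# Sketch (assumption, not part of the commit): a 1-D device mesh sharding the
# batch axis of the strain array across all visible devices.
from jax.sharding import Mesh, NamedSharding

mesh_1d = Mesh(np.array(jax.devices()), axis_names=("batch",))
batch_sharding = NamedSharding(mesh_1d, P("batch"))

# benchmarking(sharding=batch_sharding)  # hypothetical call; the function body is truncated above
```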
6 changes: 3 additions & 3 deletions jax-gpu/jax-gpu.sh
@@ -1,7 +1,7 @@
#!/bin/bash -l
-#SBATCH --nodes=2
-#SBATCH -G 8
-#SBATCH -c 28
+#SBATCH --nodes=1
+#SBATCH -G 4
+#SBATCH -c 12
#SBATCH --partition=gpu

##SBATCH --ntasks-per-node=28