jax
flax>=0.7.1
transformer_engine_cu12==2.14.0
