#!/bin/bash
#
# Checks that PyTorch can see the GPUs of an existing Slurm allocation:
# loads the Cray/ROCm modules, prepares a Python virtualenv containing
# PyTorch, and runs a one-line device-count query through srun.
#

# Abort on the first failing command. This is set here instead of on the
# shebang line so it still applies when the script is run as `bash script.sh`.
set -e

# Directory the script was launched from; the virtualenv lives under it.
wd=$(pwd)

# Job to attach the srun step to. Inside an salloc/sbatch shell Slurm exports
# SLURM_JOB_ID, which identifies the current allocation unambiguously.
# Otherwise fall back to the first job squeue lists for this user, with the
# header suppressed so that "JOBID" can never be mistaken for a job id.
if [[ -n "${SLURM_JOB_ID:-}" ]]; then
    jobid=$SLURM_JOB_ID
else
    jobid=$(squeue --me --noheader --format=%i | head -n 1)
fi

# Job ids always start with a digit (plain, array "N_M" or het "N+M" forms).
if [[ ! "$jobid" =~ ^[0-9] ]]; then
    printf 'error: no Slurm job found for the current user; create an allocation first\n' >&2
    exit 1
fi


#
# This example assumes an allocation has already been created, e.g.:
# N=1 ; salloc -p standard-g  --threads-per-core 1 --exclusive -N $N --gpus $((N*8)) -t 4:00:00 --mem 0
#

# Start from a clean module environment, then load the Cray programming
# environment, the CPU/GPU target modules and the Cray Python distribution.
module purge
module load CrayEnv
module load PrgEnv-cray/8.5.0
module load craype-x86-trento
module load craype-accel-amd-gfx90a
module load cray-python

# Make the extra module tree visible and load ROCm from it.
# (The original line repeated the words "module use", which passed them to
# `module use` as two additional directories.)
module use /appl/local/containers/test-modules
# NOTE(review): the ".lua" suffix in the module name is unusual; confirm it
# matches the modulefile name provided in test-modules.
module load rocm/6.1.3.lua

# Trace every command from here on.
set -x

# Virtualenv location, anchored to the launch directory so the existence
# check, creation and activation all refer to the same path.
venv_dir="$wd/cray-python-virtualenv"

# Install pytorch if it doesn't exist already.
if [[ ! -d "$venv_dir" ]]; then
    python -m venv --system-site-packages "$venv_dir"
    source "$venv_dir/bin/activate"
    if ! pip3 install --pre torch==2.4.1+rocm6.1 --index-url https://download.pytorch.org/whl/; then
        # Remove the half-built environment: otherwise its mere existence
        # would make every later run skip the install and fail on import.
        printf 'error: PyTorch installation failed; removing %s\n' "$venv_dir" >&2
        rm -rf -- "${venv_dir:?}"
        exit 1
    fi
else
    source "$venv_dir/bin/activate"
fi

# Refuse to launch with a missing or bogus job id (e.g. the squeue header
# word picked up when no job exists): srun's own error for that is opaque.
# Real job ids always start with a digit.
if [[ ! "$jobid" =~ ^[0-9] ]]; then
    printf 'error: invalid Slurm job id "%s"; create an allocation first\n' "$jobid" >&2
    exit 1
fi

# Run a single task with 8 GPUs inside the existing allocation and report how
# many devices PyTorch can see.
srun --jobid="$jobid" -n1 --gpus 8 \
    python -c 'import torch; print("I have this many devices:", torch.cuda.device_count())'
