#!/bin/bash
# Set up (once) a Miniconda env with a ROCm build of PyTorch under the current
# directory, then check how many GPUs PyTorch sees inside an existing Slurm
# allocation.
#
# -e is set with `set` rather than on the shebang line so it still applies when
# the script is invoked as `bash script.sh`. -u is deliberately not enabled:
# the sourced module/conda activation scripts may reference unset variables.
set -e

# The Miniconda install and the "pytorch" env live under the launch directory.
wd=$(pwd)

# Slurm job to run on: the allocation this shell was started from (salloc
# exports SLURM_JOB_ID), otherwise the first job squeue lists for this user.
if [[ -n "${SLURM_JOB_ID:-}" ]]; then
    jobid=$SLURM_JOB_ID
else
    # -h: no header line (otherwise "JOBID" is picked up when there are no
    # jobs); -o %i: print only the job ID column. Keep only the first line.
    jobids=$(squeue --me -h -o %i)
    jobid=${jobids%%$'\n'*}
fi
if [[ -z "$jobid" ]]; then
    echo "error: no Slurm job found; create an allocation with salloc first" >&2
    exit 1
fi


#
# Example: this script assumes an allocation has already been created, e.g.:
# N=1 ; salloc -p standard-g  --threads-per-core 1 --exclusive -N $N --gpus $((N*8)) -t 4:00:00 --mem 0
#

# Start from a clean module state, then load the Cray programming environment
# with the x86 Trento CPU target and the AMD gfx90a GPU accelerator target.
module purge
module load CrayEnv
module load PrgEnv-cray/8.5.0
module load craype-x86-trento
module load craype-accel-amd-gfx90a
module load cray-python

# ROCm 6.1.3 from the container test-modules tree. This is presumably chosen to
# match the torch==2.4.1+rocm6.1 wheel this script installs; confirm when bumping
# either version.
# (The original line read "module use module use <path>", which also passed the
# words "module" and "use" as module paths.)
module use /appl/local/containers/test-modules
module load rocm/6.1.3.lua

set -x

# Install Miniconda and a "pytorch" env (Python 3.11 + ROCm 6.1 build of
# PyTorch 2.4.1) under $wd, unless the env directory already exists.
# NOTE: only the env directory is checked. If a previous run failed after
# `conda create` (e.g. during pip install), delete "$wd/miniconda3" and rerun.
conda_root="$wd/miniconda3"
if [[ ! -d "$conda_root/envs/pytorch" ]]; then
    installer=Miniconda3-latest-Linux-x86_64.sh
    # -f: fail on HTTP errors instead of saving an error page and running it.
    curl -fLo "$installer" "https://repo.anaconda.com/miniconda/$installer"
    # -b: batch (non-interactive), -p: install prefix, -s: skip pre/post-link
    # and install scripts.
    bash "./$installer" -b -p "$conda_root" -s
    rm -f -- "./$installer"

    # Point the default channel list at conda-forge instead of "defaults".
    sed -i 's/defaults/conda-forge/g' "$conda_root/.condarc"

    source "$conda_root/bin/activate" base
    conda create -y -n pytorch python=3.11
    source "$conda_root/bin/activate" pytorch
    # Use the ROCm 6.1 wheel index; the pinned +rocm6.1 build is hosted there.
    pip3 install torch==2.4.1+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1
else
    source "$conda_root/bin/activate" pytorch
fi

# Run one task inside the allocation with 8 GPUs and print how many devices
# PyTorch reports. ROCm builds of PyTorch expose AMD GPUs through the torch.cuda
# API. ${jobid:?} aborts with a clear message instead of passing an empty --jobid=.
srun --jobid="${jobid:?no Slurm job id available}" -n1 --gpus 8 \
    python -c 'import torch; print("I have this many devices:", torch.cuda.device_count())'
