#!/bin/bash -e
#
# Profile one rank of a multi-node PyTorch DDP MNIST run with omnitrace,
# inside an existing Slurm allocation.
#

# The "-e" on the shebang is dropped when the script is started as
# "bash script.sh", so enable it explicitly as well.
set -e

# Working directory; generated files are written here and it is later
# bind-mounted into the container.
wd=$(pwd)

# Slurm job to attach the job steps to. Prefer the allocation we are running
# in; otherwise fall back to the first job squeue lists for this user.
# --noheader keeps the "JOBID" column title from being taken as a job id
# when the user has no jobs.
if [[ -n "${SLURM_JOB_ID:-}" ]]; then
  jobid=$SLURM_JOB_ID
else
  jobid=$(squeue --me --noheader | awk 'NR == 1 { print $1 }')
fi

if [[ -z "$jobid" ]]; then
  echo "error: no Slurm job found; create an allocation first (see salloc)" >&2
  exit 1
fi

#
# Create omnitrace configuration.
#
# Start from a clean module environment and load the Cray/ROCm toolchain
# together with the omnitrace module.
module purge
module load CrayEnv
module load PrgEnv-cray/8.5.0
module load craype-x86-trento
module load craype-accel-amd-gfx90a
module load cray-python

# Make the container test modules visible ("module use" takes directories
# only; repeating the sub-command would add bogus relative paths).
module use /appl/local/containers/test-modules
module load rocm/6.1.3.lua omnitrace/1.12.0-rocm6.1.x

omnitrace_cfg="$wd/omnitrace-config.cfg"

# Regenerate the configuration from scratch on a compute node of the job.
# omnitrace-avail is expected to write omnitrace-config.cfg into the current
# directory; the edits below rely on that file existing.
rm -rf -- "$omnitrace_cfg"
srun -N 1 -n 1 --gpus 8 --jobid="$jobid" omnitrace-avail -a -G

# Adjust the generated defaults: enable sampling, disable ROCm-SMI, sample
# no CPUs and only GPU 2.
sed -i \
  -e 's/OMNITRACE_USE_SAMPLING.*/OMNITRACE_USE_SAMPLING = true/g' \
  -e 's/OMNITRACE_USE_ROCM_SMI.*/OMNITRACE_USE_ROCM_SMI = false/g' \
  -e 's/OMNITRACE_SAMPLING_CPUS.*/OMNITRACE_SAMPLING_CPUS = none/g' \
  -e 's/OMNITRACE_SAMPLING_GPUS.*/OMNITRACE_SAMPLING_GPUS = 2/g' \
  "$omnitrace_cfg"

#
# This example assumes an allocation has already been created, e.g.:
# N=1 ; salloc -p standard-g  --threads-per-core 1 --exclusive -N $N --gpus $((N*8)) -t 4:00:00 --mem 0
#

# Trace every command from here on.
set -x

# Container image used to run the application.
SIF=/appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif

#
# Generate the wrapper script that every rank executes inside the container.
#
# The here-document delimiter is unquoted on purpose: unescaped expansions
# (whoami, omnitrace_ROOT) are resolved now, while generating the file, and
# backslash-escaped ones are left for the rank to resolve at run time.
#
# Fall back to the current directory if the working directory is not set,
# so the file is never written to "/".
: "${wd:=$(pwd)}"

rm -rf -- "$wd/run-me.sh"
cat > "$wd/run-me.sh" << EOF
#!/bin/bash -e

# Make sure GPUs are up
if [ \$SLURM_LOCALID -eq 0 ] ; then
    rocm-smi
fi
sleep 2

export MIOPEN_USER_DB_PATH="/tmp/$(whoami)-miopen-cache-\$SLURM_NODEID"
export MIOPEN_CUSTOM_CACHE_DIR=\$MIOPEN_USER_DB_PATH

# Set MIOpen cache to a temporary folder.
if [ \$SLURM_LOCALID -eq 0 ] ; then
    rm -rf \$MIOPEN_USER_DB_PATH
    mkdir -p \$MIOPEN_USER_DB_PATH
fi
sleep 2

# Report affinity
echo "Rank \$SLURM_PROCID --> \$(taskset -p \$\$)"

# Start conda environment inside the container
\$WITH_CONDA

# Add omnitrace environment
export PATH=$omnitrace_ROOT/bin:\$PATH
export LD_LIBRARY_PATH=$omnitrace_ROOT/lib:\$LD_LIBRARY_PATH
# Extend the existing Python path (if any) with the omnitrace bindings.
export PYTHONPATH=$omnitrace_ROOT/lib/python/site-packages\${PYTHONPATH:+:\$PYTHONPATH}
export OMNITRACE_CONFIG_FILE=/workdir/omnitrace-config.cfg

# Set interfaces to be used by RCCL.
export NCCL_SOCKET_IFNAME=hsn0,hsn1,hsn2,hsn3
export NCCL_NET_GDR_LEVEL=PHB

# Set environment for the app
export MASTER_PORT=29500
export WORLD_SIZE=\$SLURM_NPROCS
export RANK=\$SLURM_PROCID
export HIP_VISIBLE_DEVICES=\$SLURM_LOCALID

# Run app
cd /workdir/mnist

# Only rank 2 is run under the omnitrace sampler.
pcmd=''
if [ \$RANK -eq 2 ] ; then
   pcmd='omnitrace-sample -- '
fi

\$pcmd python -u mnist_DDP.py --gpu --modelpath /workdir/mnist/model

EOF
chmod +x "$wd/run-me.sh"

# CPU binding: one mask per local rank (8 per node). Each mask is 0xfe
# (cores 1-7 of an 8-core group) shifted to a different byte, in the order
# 48, 56, 16, 24, 0, 8, 32, 40 bits.
# NOTE(review): presumably this order matches the GPU-to-core affinity of the
# node - confirm against the system documentation.
c=fe
MYMASKS="0x${c}000000000000,0x${c}00000000000000,0x${c}0000,0x${c}000000,0x${c},0x${c}00,0x${c}00000000,0x${c}0000000000"

# Number of nodes to run on; the allocation must provide at least this many,
# each with 8 GPUs.
Nodes=2

# Rendezvous address for the ranks: first host of the job's node list.
# SLURM_NODELIST only exists inside the allocation shell, so fall back to
# asking Slurm for the node list of the selected job.
if [[ -n "${SLURM_NODELIST:-}" ]]; then
  job_nodelist=$SLURM_NODELIST
else
  job_nodelist=$(squeue --noheader --jobs="$jobid" --format=%N)
fi

# Assign and export separately so a failed lookup is not hidden by "export".
MASTER_ADDR=$(scontrol show hostname "$job_nodelist" | head -n1)
if [[ -z "$MASTER_ADDR" ]]; then
  echo "error: could not determine MASTER_ADDR for job $jobid" >&2
  exit 1
fi
export MASTER_ADDR

# Launch 8 ranks per node, each running the generated wrapper inside the
# container with the working directory mounted at /workdir.
srun --jobid="$jobid" -N "$Nodes" -n "$((Nodes * 8))" --gpus "$((Nodes * 8))" \
  --cpu-bind="mask_cpu:$MYMASKS" \
  singularity exec \
    -B /var/spool/slurmd \
    -B /opt/cray \
    -B /usr/lib64/libcxi.so.1 \
    -B "$wd:/workdir" \
    -B "$omnitrace_ROOT" \
    -B /usr/lib64/libpciaccess.so.0 \
    "$SIF" /workdir/run-me.sh
   