#!/bin/bash -e
#
# Profile the mnist_DDP.py example with Omniperf inside a PyTorch container,
# using an existing Slurm allocation.

# The shebang's -e is dropped when the script is started as `bash script.sh`,
# so enable it explicitly as well.
set -e

# Directory the script is launched from; it is bind-mounted into the
# container as /workdir further down.
wd=$(pwd)

# Job id of the allocation to run in. Prefer the id Slurm exports inside an
# allocation. Only when that is missing, query the queue: without the header
# line, take the last of the first two listed jobs (the same job the previous
# line-position based lookup selected).
if [ -n "${SLURM_JOB_ID:-}" ] ; then
    jobid=$SLURM_JOB_ID
else
    jobid=$(squeue --me --noheader --format=%i | head -n2 | tail -n1)
fi
if [ -z "$jobid" ] ; then
    echo "ERROR: no Slurm job found - create an allocation first." >&2
    exit 1
fi

# Start from a clean module environment and load the toolchain.
module purge
module load CrayEnv
module load PrgEnv-cray/8.5.0
module load craype-x86-trento
module load craype-accel-amd-gfx90a
module load cray-python

# Extra module tree providing the ROCm and Omniperf modules.
module use /appl/local/containers/test-modules
module load rocm/6.1.3.lua omniperf/2.1.0

# This example assumes an allocation has already been created, e.g.:
# N=1 ; salloc -p standard-g  --threads-per-core 1 --exclusive -N $N --gpus $((N*8)) -t 4:00:00 --mem 0
#

# Trace every command from here on.
set -x

# Container image used for both the dependency install and the profiled run.
SIF=/appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif

# OMNIPERF_DIR is bind-mounted below; an empty value would make `-B` consume
# the following option instead. NOTE(review): presumably exported by the
# omniperf module - confirm.
: "${OMNIPERF_DIR:?OMNIPERF_DIR is not set - load the omniperf module first}"

# Omniperf requires some python dependencies.
# Let's install them - match omniperf version.
# The install runs once: it is skipped whenever the venv directory exists.
if [ ! -d "$wd/omniperf-venv" ] ; then
    # `set -e` inside the container stops at the first failing step, and
    # `curl -f` turns an HTTP error into a failure instead of saving the error
    # page as requirements.txt.
    if ! srun -N1 -n1 --gpus 8 --cpu-bind=none \
        singularity exec \
        -B /var/spool/slurmd \
        -B /opt/cray \
        -B /usr/lib64/libcxi.so.1 \
        -B "$wd:/workdir" \
        -B "$OMNIPERF_DIR" \
        -B /usr/lib64/libpciaccess.so.0 \
        "$SIF" bash -c 'set -e
                    $WITH_CONDA
                    unset PYTHONPATH
                    cd /workdir
                    curl -fLO https://raw.githubusercontent.com/ROCm/omniperf/refs/tags/v2.1.0/requirements.txt
                    python -m venv --system-site-packages omniperf-venv
                    source omniperf-venv/bin/activate
                    pip install kaleido==0.2.1
                    pip install -r requirements.txt'
    then
        # Do not leave a half-built venv behind: its mere existence would make
        # the next run skip the installation.
        echo "ERROR: installing the Omniperf python dependencies failed." >&2
        rm -rf -- "${wd:?}/omniperf-venv"
        exit 1
    fi
fi

# Generate the per-rank wrapper that is executed inside the container.
# The here-document delimiter is unquoted on purpose: unescaped expansions
# (whoami, OMNIPERF_DIR) are resolved now, while everything written with a
# backslash before the dollar sign is resolved when run-me.sh runs.
rm -rf -- "$wd/run-me.sh"
cat > "$wd/run-me.sh" << EOF
#!/bin/bash -e

# Make sure GPUs are up
if [ "\$SLURM_LOCALID" -eq 0 ] ; then
    rocm-smi
fi
sleep 2

export MIOPEN_USER_DB_PATH="/tmp/$(whoami)-miopen-cache-\$SLURM_NODEID"
export MIOPEN_CUSTOM_CACHE_DIR=\$MIOPEN_USER_DB_PATH

# Set MIOpen cache to a temporary folder.
# Only the first local rank recreates it; the sleep gives it time to finish
# before the other ranks continue.
if [ "\$SLURM_LOCALID" -eq 0 ] ; then
    rm -rf "\$MIOPEN_USER_DB_PATH"
    mkdir -p "\$MIOPEN_USER_DB_PATH"
fi
sleep 2

# Report affinity
echo "Rank \$SLURM_PROCID --> \$(taskset -p \$\$)"

# Start conda environment inside the container
\$WITH_CONDA

# Add omniperf environment
export PATH=$OMNIPERF_DIR/bin:\$PATH
# Installed requirements:
export PYTHONPATH=/workdir/omniperf-venv/lib64/python3.12/site-packages

# Set interfaces to be used by RCCL.
export NCCL_SOCKET_IFNAME=hsn0,hsn1,hsn2,hsn3
export NCCL_NET_GDR_LEVEL=PHB

# Set environment for the app
export MASTER_PORT=29500
export WORLD_SIZE=\$SLURM_NPROCS
export RANK=\$SLURM_PROCID
export HIP_VISIBLE_DEVICES=\$SLURM_LOCALID

# Run app
cd /workdir/mnist

if [ "\$RANK" -eq 0 ] ; then
    # Rank 0 is the only rank that runs under the profiler.
    rm -rf /workdir/mnist/workloads
    ROCPROF=rocprofv2 \
    omniperf profile -n myprof --device 0 --roof-only -- "\$(command -v python)" -u mnist_DDP.py --gpu --modelpath /workdir/mnist/model
else
    # NOTE(review): the other ranks launch the app four times, presumably to
    # pair up with the repeated application runs omniperf performs on rank 0.
    # Confirm the count.
    for i in {1..4} ; do
        #if [ \$i -eq 1 ] ; then sleep 20 ; else sleep 5 ; fi
        python -u mnist_DDP.py --gpu --modelpath /workdir/mnist/model
    done
fi

EOF
chmod +x "$wd/run-me.sh"

# One CPU mask per rank (8 ranks per node). 0xfe selects bits 1-7 of an 8-bit
# group, so every rank gets seven cores and the lowest core of its group is
# left free. NOTE(review): the group order presumably follows the node's
# GPU-to-core affinity - confirm against the system documentation.
c=fe
MYMASKS="0x${c}000000000000,0x${c}00000000000000,0x${c}0000,0x${c}000000,0x${c},0x${c}00,0x${c}00000000,0x${c}0000000000"

Nodes=1

# The first node of the allocation acts as rendezvous host. Use the node list
# Slurm exports inside an allocation; outside of one, ask the queue for the
# nodes of the job selected via --jobid.
nodelist=${SLURM_NODELIST:-$(squeue --jobs="$jobid" --noheader --format=%N)}
# Assign first and export afterwards, and check the result explicitly:
# `export VAR=$(...)` would hide a failed lookup and export an empty address.
MASTER_ADDR=$(scontrol show hostname "$nodelist" | head -n1)
if [ -z "$MASTER_ADDR" ] ; then
    echo "ERROR: could not determine MASTER_ADDR from the job's node list." >&2
    exit 1
fi
export MASTER_ADDR

# Launch 8 ranks per node, each running the generated wrapper in the container.
srun --jobid="$jobid" -N "$Nodes" -n $((Nodes*8)) --gpus $((Nodes*8)) --cpu-bind="mask_cpu:$MYMASKS" \
  singularity exec \
    -B /var/spool/slurmd \
    -B /opt/cray \
    -B /usr/lib64/libcxi.so.1 \
    -B "$wd:/workdir" \
    -B "$OMNIPERF_DIR" \
    -B /usr/lib64/libpciaccess.so.0 \
    "$SIF" /workdir/run-me.sh

# The script ends here; everything below is never executed.
exit 0

################
# Analyse with #
################
# Reference only: the script leaves through `exit 0` before reaching this
# point. Run these commands by hand after profiling, with $wd set to the
# directory the script was launched from.
module use /appl/local/containers/test-modules
module load cray-python rocm/6.1.3.lua omniperf/2.1.0
cd "$wd/mnist"
omniperf analyze -p workloads/myprof/MI200 --gui 12345
