# Software Support
Current software support mainly targets AI/ML workloads, and we welcome contributions that test neutrino on other frameworks and workloads.
The support matrix is summarized below:
| Framework | Status |
| --- | --- |
| cuBLAS/cuFFT/cuSPARSE... | ❌ (no plan to support) |
| CUTLASS | ✅ (with a macro at build time) |
| PyTorch | ✅ (with manual building) |
| JAX | ✅ (with an environment variable) |
| Triton | ✅ |
| Taichi | ✅ |
Due to the uniqueness of neutrino, some special arrangements may be required for correct functioning with each framework:
## cuBLAS/cuDNN
neutrino does not support these NVIDIA proprietary products, for several reasons:
- NVIDIA's EULA restricts decompiling/disassembling these proprietary products.
- These products rely heavily on dark APIs, which are out of neutrino's scope.
- Even if observation were possible, developers could not act on it with optimizations, as these libraries are closed source.
Unfortunately, not supporting cuBLAS/cuFFT has some drawbacks:
- PyTorch's `nn.Linear` and other matmul/conv operations cannot be traced -> consider using CUTLASS instead.
## PyTorch
Support for PyTorch requires modifying a line in its `CMakeLists.txt`. For simplicity, we provide pre-built wheels hosted on Cloudflare R2:
- SM_75: https://pub-eef24bf0aa5b4950860ea28dfbe39d8c.r2.dev/sm_75/torch/torch-2.5.0-cp311-cp311-linux_x86_64.whl
- SM_80: https://pub-eef24bf0aa5b4950860ea28dfbe39d8c.r2.dev/sm_80/torch/torch-2.5.0-cp311-cp311-linux_x86_64.whl
- SM_86: https://pub-eef24bf0aa5b4950860ea28dfbe39d8c.r2.dev/sm_86/torch/torch-2.5.0-cp311-cp311-linux_x86_64.whl
- SM_89: https://pub-eef24bf0aa5b4950860ea28dfbe39d8c.r2.dev/sm_89/torch/torch-2.5.0-cp311-cp311-linux_x86_64.whl
- SM_90
- GFX900x (AMD CDNA)
- GFX1000x (AMD RDNA)
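If it helps, the wheel URL for a given compute capability can be derived programmatically. Below is a minimal sketch; the `wheel_url` helper is illustrative, and the mapping (drop the dot from the compute capability to get the `sm_XX` path segment) is an assumption inferred from the URL pattern in the list above:

```python
def wheel_url(compute_cap: str) -> str:
    """Map a compute capability string like "8.0" to the matching
    pre-built wheel URL, following the sm_XX pattern above (assumed)."""
    sm = "sm_" + compute_cap.strip().replace(".", "")
    base = "https://pub-eef24bf0aa5b4950860ea28dfbe39d8c.r2.dev"
    return f"{base}/{sm}/torch/torch-2.5.0-cp311-cp311-linux_x86_64.whl"

# Example: pass the output of
#   nvidia-smi --query-gpu=compute_cap --format=csv,noheader
print(wheel_url("8.0"))
```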
> **PyPI**
> Currently, links are anonymized as we are under Artifact Evaluation.
> We are working on maintaining a pip source for a better user experience. Stay tuned!
Support for PyTorch requires manual building to store PTX assembly in the installation (by default, PyTorch keeps only SASS):
1. Clone PyTorch: `git clone --recursive https://github.com/pytorch/pytorch` (add `--branch` to specify a branch if needed).
2. Follow the guide to install dependencies.
3. Query the compute capability via `nvidia-smi --query-gpu=compute_cap --format=csv,noheader`.
4. Modify the fatbin setting and add NVCC flags in `pytorch/CMakeLists.txt`, see the code block below.
5. Follow the guide to build and install PyTorch with `+PTX` in `TORCH_CUDA_ARCH_LIST`, like `TORCH_CUDA_ARCH_LIST="8.0+PTX"`.
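Steps 3 and 5 can be glued together with a small helper. This is a sketch under stated assumptions: the function names are illustrative, and it assumes appending `+PTX` to the newest detected architecture (as in the `"8.0+PTX"` example above) is what you want; adjust if you need PTX for every architecture:

```python
import subprocess

def format_arch_list(caps):
    """Turn compute capabilities like ["7.5", "8.0"] into a
    TORCH_CUDA_ARCH_LIST value, appending +PTX to the newest
    architecture so PTX is retained alongside SASS."""
    caps = sorted(set(c.strip() for c in caps if c.strip()))
    return ";".join(caps[:-1] + [caps[-1] + "+PTX"])

def detect_arch_list():
    # Query every GPU's compute capability (step 3 above).
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return format_arch_list(out.splitlines())
```

Usage: `TORCH_CUDA_ARCH_LIST="$(python detect_arch.py)"` before building, assuming the script prints `detect_arch_list()`.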
```cmake
string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin -compress-all") # [!code --]
string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin --compress=false") # [!code ++]
```
## JAX/XLA
Simply add an environment variable when executing the XLA program:
```shell
XLA_FLAGS=--xla_gpu_generate_line_info='true' python jax_workload.py
```
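The same flag can also be set from inside the Python program, as long as it happens before JAX is imported, since XLA reads `XLA_FLAGS` when the backend initializes. A minimal sketch:

```python
import os

# XLA parses XLA_FLAGS at backend initialization, so set it before
# importing jax (or any module that imports jax).
os.environ["XLA_FLAGS"] = "--xla_gpu_generate_line_info=true"

# import jax  # safe to import from this point on
```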