TorchSharp.FlashAttention-windows 0.4.0

.NET CLI:
dotnet add package TorchSharp.FlashAttention-windows --version 0.4.0

Package Manager (run in the Visual Studio Package Manager Console, which uses the NuGet module's version of Install-Package):
NuGet\Install-Package TorchSharp.FlashAttention-windows -Version 0.4.0

PackageReference (for projects that support it, copy this XML node into the project file):
<PackageReference Include="TorchSharp.FlashAttention-windows" Version="0.4.0" />

Paket CLI:
paket add TorchSharp.FlashAttention-windows --version 0.4.0

Script & Interactive (the #r directive works in F# Interactive and Polyglot Notebooks; copy it into the interactive tool or the script's source):
#r "nuget: TorchSharp.FlashAttention-windows, 0.4.0"

Cake:
// Install TorchSharp.FlashAttention-windows as a Cake Addin
#addin nuget:?package=TorchSharp.FlashAttention-windows&version=0.4.0

// Install TorchSharp.FlashAttention-windows as a Cake Tool
#tool nuget:?package=TorchSharp.FlashAttention-windows&version=0.4.0

TorchSharp.FlashAttention

Introduction

TorchSharp.FlashAttention is a C# wrapper for the Flash Attention algorithm, built on TorchSharp for deep learning in .NET. Flash Attention, developed by Dao-AILab, accelerates attention computation in Transformer models. It significantly reduces memory usage and computation time, which enables faster processing of large-scale data, especially in natural language processing and computer vision tasks.

Installation from NuGet

TorchSharp.FlashAttention is available on NuGet. Due to the size of the binaries, it is split into two packages: one for Windows (TorchSharp.FlashAttention-windows) and one for Linux (TorchSharp.FlashAttention-linux). Both packages ship with precompiled binaries, so nothing else needs to be installed apart from the prerequisites.

Prerequisites

  • .NET SDK 6.0+
  • TorchSharp-cuda-windows or TorchSharp-cuda-linux package

Compatibility:

  • For TorchSharp version 0.102.x, use TorchSharp.FlashAttention version <= 0.2.2
  • For TorchSharp version >= 0.103.x and <= 0.104.x, use TorchSharp.FlashAttention version == 0.3.0
  • For TorchSharp version >= 0.105.x, use TorchSharp.FlashAttention version >= 0.4.0
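As an illustration of the compatibility list above, a Windows project file that pairs matching versions might look like the following sketch. The TorchSharp-cuda-windows version shown (0.105.0) is an assumption chosen to fall in the ">= 0.105.x" range; any other version in that range should pair with 0.4.0 in the same way. On Linux, the -linux variants of both packages would be referenced instead.

```xml
<!-- Sketch only: the exact TorchSharp-cuda-windows version is an assumption. -->
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net6.0</TargetFramework>
  </PropertyGroup>

  <ItemGroup>
    <!-- TorchSharp >= 0.105.x pairs with TorchSharp.FlashAttention >= 0.4.0 -->
    <PackageReference Include="TorchSharp-cuda-windows" Version="0.105.0" />
    <PackageReference Include="TorchSharp.FlashAttention-windows" Version="0.4.0" />
  </ItemGroup>

</Project>
```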

For building from source, see below.

Usage

All the attention-related functions in the flash_attn package have been ported and are ready to use. The remaining flash-attention operations (such as FusedDense and LayerNorm) will be added in future versions.

For each function, we allow all the parameters that are accessible through the Python interface.

The package currently references FlashAttention 2.5.5, including ALiBi embeddings and the forward pass with KV cache.

The interfaces that have been ported over:

// FlashAttentionInterface
FlashAttentionInterface.flash_attn_func(...);
FlashAttentionInterface.flash_attn_kvpacked_func(...);
FlashAttentionInterface.flash_attn_qkvpacked_func(...);
FlashAttentionInterface.flash_attn_varlen_func(...);
FlashAttentionInterface.flash_attn_varlen_kvpacked_func(...);
FlashAttentionInterface.flash_attn_varlen_qkvpacked_func(...);
FlashAttentionInterface.flash_attn_with_kvcache(...);

// BertPadding
BertPadding.pad_input(...);
BertPadding.unpad_input(...);
BertPadding.unpad_input_for_concatenated_sequences(...);

Example using the interface

Here is a simple example for using the FlashAttention interface in C#:

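using TorchSharp;
using TorchSharp.FlashAttention;

// Packed QKV tensor of shape [batch, seqlen, 3, nheads, headdim], in half precision on the GPU
var (batch_size, seqlen, headdim, nheads) = (5, 12, 32, 4);
var qkv = torch.rand([batch_size, seqlen, 3, nheads, headdim]).half().cuda();
// The first element of the returned tuple is the attention output
var (result, _, _) = FlashAttentionInterface.flash_attn_qkvpacked_func(qkv);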

Comparison to Python:

import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, headdim, heads = 5, 12, 32, 4
qkv = torch.rand(batch_size, seqlen, 3, heads, headdim).half().cuda()
res = flash_attn_qkvpacked_func(qkv)
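The packed layout used in both examples can be built from separate query, key, and value tensors. The sketch below assumes each of q, k, and v has shape [batch, seqlen, nheads, headdim], as in the Python flash_attn interface, and stacks them along a new axis at position 2 before calling the same function as above. It needs a CUDA-capable GPU to run, and the variable names are illustrative only.

```csharp
using System;
using TorchSharp;
using TorchSharp.FlashAttention;

var (batch_size, seqlen, nheads, headdim) = (5, 12, 4, 32);
var shape = new long[] { batch_size, seqlen, nheads, headdim };

// Separate query, key and value tensors, each [batch, seqlen, nheads, headdim],
// in half precision on the GPU (as in the example above).
var q = torch.rand(shape).half().cuda();
var k = torch.rand(shape).half().cuda();
var v = torch.rand(shape).half().cuda();

// Stack along a new axis at position 2, giving [batch, seqlen, 3, nheads, headdim].
var qkv = torch.stack(new[] { q, k, v }, 2);
Console.WriteLine(string.Join(", ", qkv.shape));

// Same call as in the example above; the first tuple element is the attention output.
var (result, _, _) = FlashAttentionInterface.flash_attn_qkvpacked_func(qkv);
Console.WriteLine(string.Join(", ", result.shape));
```

In the Python interface the output has the same shape as q, so the second printed shape is a quick way to check that the C# port behaves the same on your setup.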

Example using the FlashAttention module

In addition to the interface, we also include a custom TorchSharp module for applying QKV packed attention, with a custom key-padding mask:

using TorchSharp;
using TorchSharp.FlashAttention;
using static TorchSharp.torch;

// A module that takes two tensors (input and key-padding mask) and returns a tensor
class MyModule : nn.Module<Tensor, Tensor, Tensor> {
    private readonly FlashAttention _flash;
    // ... rest of your fields

    public MyModule() : base("MyModule") {
        _flash = new(softmax_scale: 1, attention_dropout: 0.1, causal: true);
        // ... rest of your module

        // Register the submodules declared as fields above
        RegisterComponents();
    }

    public override Tensor forward(Tensor input, Tensor key_padding_mask) {
        // ...
        var attnOutput = _flash.forward(input, key_padding_mask);
        // ... post-process attnOutput and return the result
    }
}

Building from Source

Building from source takes multiple steps: compiling the CUDA binaries, building the native library bindings, and then building the C# library.

Compiling the CUDA binaries can take a long time, so you can instead download the pre-compiled binaries listed below and extract the ZIP into Redist/compiled-runtimes. If you want to recompile them yourself, instructions are included below.

Step 1: Clone the repository:

git clone https://github.com/shaltielshmid/TorchSharp.FlashAttention.git
cd TorchSharp.FlashAttention
git lfs pull

The next steps vary depending on whether you are on Windows or Linux.

On Windows:

  • Open Visual Studio (I tested it with VS2022)
  • Build all in Release
    • The first build can take a while because it downloads and extracts the libtorch bindings, clones the FlashAttention repository, and applies a patch to make it exportable for C++.
  • The built binaries can be found in: ...\TorchSharp.FlashAttention\TorchSharp.FlashAttention\obj\Debug\net6.0\

On Linux:

  • Navigate to the repository directory
  • Run:
    export CUDA_PATH_V12_1=/path/to/cuda/toolkit/root
    dotnet build TorchSharp.FlashAttention -c Release
    

Compiling Flash Attention CUDA binaries

  • Make sure you have run a build at least once, so that the source code has been retrieved.
  • Navigate to TorchSharp.FlashAttention\Redist\flash-attn-2.5.5
  • Run either compile_flash.bat (on Windows) or bash compile_flash.sh (on Linux).

Pre-compiled CUDA Binaries for FlashAttention

Acknowledgments

This project is a C# wrapper around the original Flash Attention implementation by Dao-AILab. Immense gratitude goes to the creators and contributors of Flash Attention for their innovative work and for providing guidelines to encourage community-driven adaptations and extensions.

Contributions

Contributions to TorchSharp.FlashAttention are warmly welcomed. Whether it's adding new features, improving documentation, or fixing bugs, your input is valuable. Please feel free to submit pull requests or open issues to discuss potential changes or report bugs.

Target frameworks

  • Included in the package: net6.0
  • Computed as compatible: the net6.0 platform variants (android, ios, maccatalyst, macos, tvos, windows); net7.0 and its platform variants (android, ios, maccatalyst, macos, tvos, windows); net8.0 and its platform variants (android, browser, ios, maccatalyst, macos, tvos, windows)

Dependencies

  • net6.0: No dependencies.

Version  Downloads  Last updated
0.4.0    451        12/24/2024
0.3.0    467        11/8/2024
0.2.2    1,647      9/16/2024
0.2.1    1,813      3/14/2024
0.2.0    804        2/28/2024
0.1.0    702        2/10/2024

0.4.0:
  - Updated to TorchSharp >= 0.105.0 with LibTorch 2.5.1

0.3.0:
  - Updated to TorchSharp >= 0.103.0 with LibTorch 2.4.0