| | |
| |
|
| | import sys |
| | import warnings |
| | import os |
| | import re |
| | import ast |
| | import glob |
| | import shutil |
| | from pathlib import Path |
| | from packaging.version import parse, Version |
| | import platform |
| |
|
| | from setuptools import setup, find_packages |
| | import subprocess |
| |
|
| | import urllib.request |
| | import urllib.error |
| | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel |
| |
|
| | import torch |
| | from torch.utils.cpp_extension import ( |
| | BuildExtension, |
| | CppExtension, |
| | CUDAExtension, |
| | CUDA_HOME, |
| | ROCM_HOME, |
| | IS_HIP_EXTENSION, |
| | ) |
| |
|
| |
|
# PyPI long description: the Markdown README, looked up relative to the
# current working directory and decoded as UTF-8.
long_description = Path("README.md").read_text(encoding="utf-8")
| |
|
| |
|
| | |
| | this_dir = os.path.dirname(os.path.abspath(__file__)) |
| |
|
| | BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") |
| |
|
| | if BUILD_TARGET == "auto": |
| | if IS_HIP_EXTENSION: |
| | IS_ROCM = True |
| | else: |
| | IS_ROCM = False |
| | else: |
| | if BUILD_TARGET == "cuda": |
| | IS_ROCM = False |
| | elif BUILD_TARGET == "rocm": |
| | IS_ROCM = True |
| |
|
| | PACKAGE_NAME = "flash_attn" |
| |
|
| | BASE_WHEEL_URL = ( |
| | "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" |
| | ) |
| |
|
| | |
| | |
| | FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" |
| | SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" |
| | |
| | FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" |
| |
|
| |
|
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    The result follows the wheel platform-tag convention, e.g.
    ``linux_x86_64``, ``macosx_14_0_arm64`` or ``win_amd64``.

    Raises:
        ValueError: if ``sys.platform`` is not Linux, macOS or Windows.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    if sys.platform == "darwin":
        # Wheel tags separate version components with underscores
        # ("macosx_14_0"), never dots, and must name the real CPU architecture
        # so Apple Silicon hosts are not reported as x86_64.
        mac_version = "_".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_{platform.machine()}"
    if sys.platform == "win32":
        return "win_amd64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))
| |
|
| |
|
def get_cuda_bare_metal_version(cuda_dir):
    """Ask the nvcc inside ``cuda_dir`` which CUDA release it belongs to.

    Runs ``<cuda_dir>/bin/nvcc -V`` and takes the token that follows the word
    ``release`` in its output, dropping anything from the first comma on.

    Returns:
        ``(raw_output, version)``: the complete nvcc text and the release
        parsed with ``packaging.version.parse``.

    Raises:
        subprocess.CalledProcessError: if nvcc exits with a non-zero status.
        ValueError: if the output contains no ``release`` token.
    """
    nvcc_command = [cuda_dir + "/bin/nvcc", "-V"]
    raw_output = subprocess.check_output(nvcc_command, universal_newlines=True)
    words = raw_output.split()
    release_word = words[words.index("release") + 1]
    version_text, _, _ = release_word.partition(",")
    return raw_output, parse(version_text)
| |
|
| |
|
def check_if_cuda_home_none(global_option: str) -> None:
    """Warn, without aborting, when torch could not locate a CUDA toolkit.

    Args:
        global_option: name of the feature that needs nvcc; it is echoed at
            the start of the warning text.
    """
    if CUDA_HOME is None:
        hint = (
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
        warnings.warn(hint)
| |
|
| |
|
def check_if_rocm_home_none(global_option: str) -> None:
    """Warn, without aborting, when torch could not locate a ROCm install.

    Args:
        global_option: name of the feature that needs hipcc; it is echoed at
            the start of the warning text.
    """
    if ROCM_HOME is None:
        missing_msg = f"{global_option} was requested, but hipcc was not found."
        warnings.warn(missing_msg)
| |
|
| |
|
def append_nvcc_threads(nvcc_extra_args):
    """Return a new argument list ending with ``--threads N`` for nvcc.

    N is read from the NVCC_THREADS environment variable; an unset or empty
    value falls back to "4". The list passed in is not modified.
    """
    thread_count = os.getenv("NVCC_THREADS")
    if not thread_count:
        thread_count = "4"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option
| |
|
| |
|
def rename_cpp_to_cu(cpp_files):
    """Copy each source file to a sibling with a ``.cu`` extension.

    Despite the name, nothing is renamed: ``foo.cpp`` stays in place and a
    copy is written to ``foo.cu``, overwriting any existing file of that name.
    """
    for source_path in cpp_files:
        stem, _old_extension = os.path.splitext(source_path)
        shutil.copy(source_path, f"{stem}.cu")
| |
|
| |
|
def validate_and_update_archs(archs):
    """Check that every requested ROCm GPU architecture is supported.

    Despite its name, this only validates; ``archs`` is not modified and the
    function returns None.

    Args:
        archs: iterable of architecture names (the build splits the
            GPU_ARCHS environment variable on ';' to produce it).

    Raises:
        ValueError: if any entry is not in the supported list. This is an
            explicit raise rather than ``assert`` so the check still runs
            when Python is started with ``-O``.
    """
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
    archs = list(archs)
    unsupported = [arch for arch in archs if arch not in allowed_archs]
    if unsupported:
        raise ValueError(
            f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
            f" (unsupported: {unsupported}; allowed: {allowed_archs})"
        )
| |
|
| |
|
# NOTE(review): appears unused — setup() is given its own cmdclass mapping.
cmdclass = {}
ext_modules = []

# Fetch only the third-party kernel library the selected backend needs:
# composable_kernel for ROCm, CUTLASS for CUDA.
if IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # If the installed torch still ships ATen/CUDAGeneratorImpl.h, define
    # OLD_GENERATOR_PATH for the C++/CUDA sources.
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("flash_attn")

    # Always target sm_80; add sm_90 when nvcc is CUDA 11.8 or newer.
    cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
        if bare_metal_version >= Version("11.8"):
            cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    # One translation unit per (kernel, masking, head dim, dtype) combination.
    # Order: fwd, bwd, fwd_split; within each, non-causal before causal, head
    # dims ascending, fp16 before bf16.
    cuda_sources = ["csrc/flash_attn/flash_api.cpp"] + [
        f"csrc/flash_attn/src/flash_{kernel}_hdim{hdim}_{dtype}{mask}_sm80.cu"
        for kernel in ("fwd", "bwd", "fwd_split")
        for mask in ("", "_causal")
        for hdim in (32, 64, 96, 128, 160, 192, 256)
        for dtype in ("fp16", "bf16")
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=cuda_sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    ck_dir = "csrc/composable_kernel"

    # Generate the ck_tile FMHA forward and backward kernel sources into ./build.
    # The generator is run with an argument list (no shell), so an interpreter
    # path containing spaces still works. check=True stops the build here if
    # generation fails, instead of continuing with an incomplete source set.
    os.makedirs("build", exist_ok=True)
    for direction in ("fwd", "bwd"):
        subprocess.run(
            [
                sys.executable,
                f"{ck_dir}/example/ck_tile/01_fmha/generate.py",
                "-d", direction,
                "--output_dir", "build",
                "--receipt", "2",
            ],
            check=True,
        )

    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # If the installed torch still ships ATen/CUDAGeneratorImpl.h, define
    # OLD_GENERATOR_PATH for the C++ sources.
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_rocm_home_none("flash_attn")

    # GPU_ARCHS is a ';'-separated list of offload targets (default "native").
    archs = os.getenv("GPU_ARCHS", "native").split(";")
    validate_and_update_archs(archs)
    cc_flag = [f"--offload-arch={arch}" for arch in archs]

    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    sources = [
        "csrc/flash_attn_ck/flash_api.cpp",
        "csrc/flash_attn_ck/mha_bwd.cpp",
        "csrc/flash_attn_ck/mha_fwd.cpp",
        "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
        "csrc/flash_attn_ck/mha_varlen_fwd.cpp",
    ] + glob.glob("build/fmha_*wd*.cpp")

    # Copy every .cpp to a .cu sibling and compile exactly those copies. The
    # list is derived from `sources` rather than globbing build/*.cu, so stale
    # .cu files left over from an earlier build are not compiled in.
    rename_cpp_to_cu(sources)
    renamed_sources = [os.path.splitext(src)[0] + ".cu" for src in sources]

    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-mllvm",
            "-enable-post-misched=0",
            "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
            "-fgpu-flush-denormals-to-zero",
            "-DCK_ENABLE_BF16",
            "-DCK_ENABLE_BF8",
            "-DCK_ENABLE_FP16",
            "-DCK_ENABLE_FP32",
            "-DCK_ENABLE_FP64",
            "-DCK_ENABLE_FP8",
            "-DCK_ENABLE_INT8",
            "-DCK_USE_XDL",
            "-DUSE_PROF_API=1",
            "-D__HIP_PLATFORM_HCC__=1",
        ]
        + generator_flag
        + cc_flag,
    }

    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )
| |
|
| |
|
def get_package_version():
    """Return the version string for the distribution.

    The public version is the literal assigned to ``__version__`` in
    ``flash_attn/__init__.py``. If FLASH_ATTN_LOCAL_VERSION is set and
    non-empty, it is appended after a ``+``, e.g. ``2.6.3+cu121``.

    Raises:
        RuntimeError: if ``__init__.py`` contains no ``__version__`` line.
    """
    init_path = Path(this_dir) / "flash_attn" / "__init__.py"
    # Explicit encoding so the result does not depend on the platform locale.
    with open(init_path, "r", encoding="utf-8") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    return str(public_version)
| |
|
| |
|
def get_wheel_url():
    """Construct the GitHub release URL of a prebuilt wheel for this environment.

    The wheel name encodes the flash_attn version, the GPU runtime tag
    (``cu<major><minor>`` or ``rocm<major><minor>``), torch's major.minor
    version, the C++11 ABI flag, the CPython tag (used for both the python
    and ABI fields) and the platform tag.

    Returns:
        ``(wheel_url, wheel_filename)``.
    """
    torch_release = parse(torch.__version__)
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_release.major}.{torch_release.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if IS_ROCM:
        # Keep the last whitespace-separated token of torch.version.hip, drop
        # trailing '-' characters and turn the remaining '-' into '+' so the
        # string parses as a version.
        hip_token = torch.version.hip.split()[-1].rstrip('-').replace('-', '+')
        hip_release = parse(hip_token)
        runtime_tag = f"rocm{hip_release.major}{hip_release.minor}"
    else:
        # Wheels are looked up only for CUDA 11.8 and 12.3: a CUDA 11.x torch
        # maps to 11.8, every other major version to 12.3.
        # NOTE(review): torch.version.cuda is None on CPU-only torch builds,
        # which makes parse() fail here — confirm whether that needs handling.
        detected_cuda = parse(torch.version.cuda)
        if detected_cuda.major == 11:
            wheel_cuda = parse("11.8")
        else:
            wheel_cuda = parse("12.3")
        runtime_tag = f"cu{wheel_cuda.major}{wheel_cuda.minor}"

    wheel_filename = (
        f"{PACKAGE_NAME}-{flash_version}+{runtime_tag}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{py_tag}-{py_tag}-{platform_name}.whl"
    )
    release_tag = f"v{flash_version}"
    wheel_url = BASE_WHEEL_URL.format(tag_name=release_tag, wheel_name=wheel_filename)
    return wheel_url, wheel_filename
| |
|
| |
|
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        """Use a matching prebuilt wheel if it can be downloaded, else build.

        With FLASH_ATTENTION_FORCE_BUILD=TRUE the download is skipped and the
        regular bdist_wheel build runs immediately.
        """
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # No prebuilt wheel could be fetched: fall back to the full build.
            super().run()
            return

        # Put the downloaded wheel into dist_dir under the name bdist_wheel
        # itself would have produced, so pip finds it as the build output.
        os.makedirs(self.dist_dir, exist_ok=True)
        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move falls back to copy-and-delete when dist_dir lives on a
        # different filesystem than the working directory; os.rename would
        # raise OSError (EXDEV) in that case.
        shutil.move(wheel_filename, wheel_path)
| |
|
| |
|
class NinjaBuildExtension(BuildExtension):
    """BuildExtension that chooses a default MAX_JOBS before building.

    When MAX_JOBS is not already set, the job count is capped at half the
    logical CPUs and at one job per ``_GB_PER_JOB`` GiB of currently
    available memory, and is never below one. A MAX_JOBS value set by the
    user is left untouched.
    """

    # Memory budget assumed per parallel compile job, in GiB.
    _GB_PER_JOB = 9

    def __init__(self, *args, **kwargs) -> None:
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # os.cpu_count() returns None when the count cannot be determined;
            # treat that as a single CPU instead of failing with a TypeError.
            cpu_count = os.cpu_count() or 1
            max_num_jobs_cores = max(1, cpu_count // 2)

            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)
            max_num_jobs_memory = int(free_memory_gb / self._GB_PER_JOB)

            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)
| |
|
| |
|
# Package metadata and build wiring. The wheel-caching bdist_wheel command is
# always registered; the MAX_JOBS-aware build_ext only when extensions exist.
setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={
        "bdist_wheel": CachedWheelsCommand,
        **({"build_ext": NinjaBuildExtension} if ext_modules else {}),
    },
    python_requires=">=3.8",
    install_requires=["torch", "einops"],
    setup_requires=["packaging", "psutil", "ninja"],
)