import heapq

import torch

import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops


def _get_device_properties(device):
    ''' return the device-properties mapping for `device`, initializing the CUDA utils first '''
    # Every lookup goes through here so that `triton.compiler.cuda_utils` is
    # always initialized before it is dereferenced, regardless of which
    # public entry point is called first.
    triton.compiler.init_cuda_utils()
    return triton.compiler.cuda_utils.get_device_properties(device)


def _subcore_occupancy(device, num_ctas, num_warps):
    ''' return the fraction (0..1] of the device's sub-cores kept busy by the launch '''
    # At most 4 warps per CTA are counted towards occupancy.
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = _get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    return min(num_subcores, total_warps) / num_subcores


def _stage_distance(num_stages, optimal_num_stages):
    '''
    ranking key used to pick the `num_stages` closest to the optimum;
    values below the optimum get a flat penalty of 10 so that larger
    stage counts are preferred
    '''
    delta = num_stages - optimal_num_stages
    if delta < 0:
        return 10 + abs(delta)
    return delta


def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    # Peak value reported by `get_max_tensorcore_tflops`, scaled by occupancy.
    tflops = _subcore_occupancy(device, num_ctas, num_warps) * get_max_tensorcore_tflops(dtype, backend, device)
    return tflops


def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    # Peak value reported by `get_max_simd_tflops`, scaled by occupancy.
    tflops = _subcore_occupancy(device, num_ctas, num_warps) * get_max_simd_tflops(dtype, backend, device)
    return tflops


def get_tflops(backend, device, num_ctas, num_warps, dtype):
    '''
    return compute throughput in TOPS, using the SIMD estimate for float32
    on devices with compute capability major < 8 and the tensor-core
    estimate otherwise
    '''
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8 and dtype == torch.float32:
        return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)


def estimate_matmul_time(
    # backend, device,
    num_warps, num_stages,
    A, B, C,
    M, N, K,
    BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
    debug=False, **kwargs
):
    ''' return estimated running time in ms
          = max(compute, loading) + store

    `A` supplies the dtype and element size used by the model; `B`, `C`,
    `num_stages`, `BLOCK_K` and any extra `kwargs` are accepted but not
    used by the estimate. With `debug=True` the breakdown is printed.
    '''
    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    num_cta_m = triton.cdiv(M, BLOCK_M)
    num_cta_n = triton.cdiv(N, BLOCK_N)
    num_cta_k = SPLIT_K
    num_ctas = num_cta_m * num_cta_n * num_cta_k

    # If the input is smaller than the block size
    M, N = max(M, BLOCK_M), max(N, BLOCK_N)

    # time to compute
    total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
    compute_ms = total_ops / tput

    # time to load data
    num_sm = _get_device_properties(device)["multiprocessor_count"]
    active_cta_ratio = min(1, num_ctas / num_sm)
    active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
    # NOTE(review): 108 is hard-coded here rather than derived from num_sm;
    # presumably tuned for one specific GPU - confirm before relying on it.
    active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
    dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05)  # in GB/s
    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)

    # assume 80% of (following) loads are in L2 cache
    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)

    # total
    total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
    total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)

    # loading time in ms
    load_ms = total_dram / dram_bw + total_l2 / l2_bw

    # estimate storing time
    store_bw = dram_bw * 0.6  # :o
    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
    # The split-K reduction is modelled at the same bandwidth as a plain
    # store, so both cases share this term.
    store_ms = store_c_dram / store_bw
    if SPLIT_K != 1:
        # c.zero_()
        zero_ms = M * N * 2 / (1024 * 1024) / store_bw
        store_ms += zero_ms

    total_time_ms = max(compute_ms, load_ms) + store_ms
    if debug:
        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
              f'Activate CTAs: {active_cta_ratio*100}%')
    return total_time_ms


def early_config_prune(configs, named_args):
    '''
    return the subset of `configs` worth benchmarking

    Steps:
      1. drop configs whose operand tiles do not fit in shared memory
      2. drop SPLIT_K != 1 configs for dtypes other than float16/float32
      3. per (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) group, keep at
         most two configs with num_stages nearest the estimated optimum
         (compute capability major >= 8), or a single config forced to
         num_stages = 2 otherwise

    `named_args['A']` supplies the dtype and element size.
    NOTE: on the capability < 8 path the kept config object is modified in
    place (its `num_stages` is overwritten).
    '''
    device = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args['A'].element_size()
    dtype = named_args['A'].dtype

    # 1. make sure we have enough smem
    # The limit is a property of the device, so query it once instead of
    # once per candidate config.
    max_shared_memory = _get_device_properties(device)["max_shared_mem"]
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
    configs = pruned_configs

    # Some dtypes do not allow atomic_add
    if dtype not in [torch.float16, torch.float32]:
        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]

    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        key = (kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps)
        configs_map.setdefault(key, []).append((config, config.num_stages))

    pruned_configs = []
    for k, v in configs_map.items():
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
        if capability[0] >= 8:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8

            ldgsts_latency = 300  # Does this matter?
            optimal_num_stages = ldgsts_latency / mma_cycles

            # nearest stages, prefer large #stages
            nearest = heapq.nsmallest(
                2, v,
                key=lambda entry: _stage_distance(entry[1], optimal_num_stages),
            )
            pruned_configs.extend(entry[0] for entry in nearest)
        else:  # Volta & Turing only supports num_stages <= 2
            # Keep the first config of the group and clamp its stage count;
            # this writes to the caller's config object.
            random_config = v[0][0]
            random_config.num_stages = 2
            pruned_configs.append(random_config)
    return pruned_configs