import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import importlib.util  # `import importlib` alone does not guarantee `importlib.util` is loaded
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref

import torch
import torch._inductor.test_operators

try:
    import torch._prims

    # isort: split
    # TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
    import torch._prims.utils
    import torch._prims.wrappers
    import torch._refs
    import torch._refs.nn
    import torch._refs.nn.functional
    import torch._refs.special

    HAS_PRIMS_REFS = True
except ImportError:
    HAS_PRIMS_REFS = False

from . import comptime, config, external_utils

"""
A note on skipfiles:

Dynamo consults this file to determine whether code should be compiled or skipped.

A skip applies at the frame boundary, meaning dynamo either triggers a graph break
at the beginning of the frame or attempts to trace the whole frame. When skipping
a frame, recursively called frames are still traced by dynamo unless also skipped.

Skipfiles (skipped at the file level instead of function level) still apply on a
frame-by-frame boundary as dynamo traces, but apply to all functions in that file.

@skip is a helper decorator that can be applied to your function to cause it to be
included here.
"""


def _strip_init_py(s):
    """Return ``s`` with a trailing ``__init__.py`` removed.

    For a package's ``__file__`` this yields the package directory (with its
    trailing separator); any other path is returned unchanged.
    """
    # Escape the dot so only a literal "__init__.py" suffix is stripped.
    return re.sub(r"__init__\.py$", "", s)


def _module_dir(m: types.ModuleType):
    """Return the path prefix for module ``m``.

    This is the package directory when ``m`` is a package (its ``__file__`` ends
    in ``__init__.py``); otherwise it is the module's file path itself.
    """
    return _strip_init_py(m.__file__)


# Filename prefixes whose frames dynamo skips (see ``check``).
SKIP_DIRS = [
    # torch.*
    _module_dir(torch),
    # torchdynamo.*
    os.path.dirname(__file__) + "/",
    # Pseudo-filenames for code without a real source file: CPython's frozen
    # importlib modules and numpy's generated __array_function__ dispatchers.
    # NOTE: never put an empty string here -- it would match every filename.
    "<frozen importlib",
    "<__array_function__ internals>",
] + [
    # skip some standard libs
    _module_dir(m)
    for m in (
        abc,
        collections,
        contextlib,
        copy,
        copyreg,
        dataclasses,
        enum,
        functools,
        importlib,
        inspect,
        linecache,
        logging,
        multiprocessing,
        operator,
        os,
        posixpath,
        random,
        re,
        selectors,
        signal,
        tempfile,
        threading,
        tokenize,
        traceback,
        types,
        typing,
        unittest,
        weakref,
        _collections_abc,
        _weakrefset,
    )
]

# Exact filenames that are traced even though they fall under SKIP_DIRS.
FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing
FILENAME_ALLOWLIST |= {
    inspect.getfile(obj)
    for obj in torch.optim.__dict__.values()
    if inspect.isclass(obj)
}
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

if HAS_PRIMS_REFS:
    FILENAME_ALLOWLIST |= {
        torch._prims.__file__,
        torch._prims.utils.__file__,
        torch._prims.wrappers.__file__,
        torch._refs.__file__,
        torch._refs.special.__file__,
        torch._refs.nn.functional.__file__,
    }

# Compiled from SKIP_DIRS by _recompile_re(); None until the first compile below.
SKIP_DIRS_RE = None


def _recompile_re():
    """Rebuild ``SKIP_DIRS_RE`` to match any filename starting with a ``SKIP_DIRS`` entry."""
    global SKIP_DIRS_RE
    # Drop empty entries: an empty alternative would make the anchored pattern
    # match every filename, so every non-allowlisted file would be skipped.
    prefixes = (re.escape(d) for d in SKIP_DIRS if d)
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(prefixes)})")


def add(import_name: typing.Union[str, types.ModuleType]):
    """Register a module's source location in ``SKIP_DIRS`` and recompile the regex.

    Args:
        import_name: a module object or a module name. Names that
            ``importlib.util.find_spec`` cannot locate, or whose spec has no
            ``origin``, are silently ignored.
    """
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()


def check(filename, allow_torch=False):
    """Should skip this file?

    Args:
        filename: source filename of the frame being considered, or None.
        allow_torch: when True, torch files (excluding dynamo's own directory,
            see ``is_torch``) are not skipped.

    Returns:
        True if the file should be skipped: ``filename`` is None, or it matches
        a ``SKIP_DIRS`` prefix and is not exempted. False for files in
        ``FILENAME_ALLOWLIST`` or exempted by ``allow_torch``.
    """
    if filename is None:
        return True
    if filename in FILENAME_ALLOWLIST:
        return False
    if allow_torch and is_torch(filename):
        return False
    return bool(SKIP_DIRS_RE.match(filename))


# skip common third party libs
for _name in (
    "functorch",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "fx2trt_oss",
    "xarray",
):
    add(_name)

# Ensure SKIP_DIRS_RE is compiled even if none of the libraries above were found.
_recompile_re()


def is_torch_inline_allowed(filename):
    """Return True if ``filename`` lies under a module in ``config.skipfiles_inline_module_allowlist``."""
    return any(
        filename.startswith(_module_dir(mod))
        for mod in config.skipfiles_inline_module_allowlist
    )


@functools.lru_cache(None)
def dynamo_dir():
    """Return (and cache) the directory prefix of the ``torch._dynamo`` package."""
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    """Return True if ``filename`` is under torch but not under ``torch._dynamo``."""
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))