import os
[docs]class Registry:
'''The class for registry of modules.'''
mapping = {
"attacks": {},
"models": {},
"lr_schedulers": {},
"transforms": {},
"paths": {},
}
[docs] @classmethod
def register_attack(cls, name=None, force=False):
r"""Register an attack method to registry with key 'name'
Args:
name (str): Key with which the attacker will be registered.
force (bool): Whether to register when the name has already existed in registry.
"""
def wrap(attack):
registerd_name = attack.__name__ if name is None else name
if registerd_name in cls.mapping["attacks"] and not force:
raise KeyError(
"Name '{}' already registered for {}.".format(
registerd_name, cls.mapping["attacks"][registerd_name]
)
)
cls.mapping["attacks"][registerd_name] = attack
return attack
return wrap
[docs] @classmethod
def register_model(cls, name=None, force=False):
r"""Register a model to registry with key 'name'
Args:
name (str): Key with which the attacker will be registered.
force (bool): Whether to register when the name has already existed in registry.
"""
def wrap(model):
registerd_name = model.__name__ if name is None else name
if registerd_name in cls.mapping["models"] and not force:
raise KeyError(
"Name '{}' already registered for {}.".format(
registerd_name, cls.mapping["models"][registerd_name]
)
)
cls.mapping["models"][registerd_name] = model
return model
return wrap
[docs] @classmethod
def register_lr_scheduler(cls, name=None, force=False):
r"""Register a learning rate scheduler to registry with key 'name'
Args:
name (str): Key with which the attacker will be registered.
force (bool): Whether to register when the name has already existed in registry.
"""
def wrap(lr_scheduler):
registerd_name = lr_scheduler.__name__ if name is None else name
if registerd_name in cls.mapping["lr_schedulers"] and not force:
raise KeyError(
"Name '{}' already registered for {}.".format(
registerd_name, cls.mapping["lr_schedulers"][registerd_name]
)
)
cls.mapping["lr_schedulers"][registerd_name] = lr_scheduler
return lr_scheduler
return wrap
[docs] @classmethod
def register_path(cls, name, path):
r"""Register a path to registry with key 'name'
Args:
name (str): Key with which the path will be registered.
"""
assert isinstance(path, str), "All path must be str."
if name in cls.mapping["paths"]:
return
if not os.path.exists(path):
os.makedirs(path)
cls.mapping["paths"][name] = path
[docs] @classmethod
def get_attack(cls, name):
'''Get a attack method by given name.'''
if cls.mapping["attacks"].get(name, None):
return cls.mapping["attacks"].get(name)
raise KeyError(f'{name} is not registered!')
[docs] @classmethod
def get_model(cls, name):
'''Get a model object by given name.'''
if cls.mapping["models"].get(name, None):
return cls.mapping["models"].get(name)
raise KeyError(f'{name} is not registered!')
[docs] @classmethod
def get_lr_scheduler(cls, name):
'''Get a lr scheduler object by given name.'''
if cls.mapping["lr_schedulers"].get(name, None):
return cls.mapping["lr_schedulers"].get(name)
raise KeyError(f'{name} is not registered!')
[docs] @classmethod
def get_path(cls, name):
'''Get a path by given name.'''
if cls.mapping["paths"].get(name, None):
return cls.mapping["paths"].get(name)
raise KeyError(f'{name} is not registered!')
[docs] @classmethod
def list_attacks(cls):
'''List all attack methods registered.'''
return sorted(cls.mapping["attacks"].keys())
[docs] @classmethod
def list_models(cls):
'''List all model classes registered.'''
return sorted(cls.mapping["models"].keys())
[docs] @classmethod
def list_lr_schedulers(cls):
'''List all lr schedulers registered.'''
return sorted(cls.mapping["models"].keys())
registry = Registry()