Source code for maltoolbox.model

"""
MAL-Toolbox Model Module
"""

from __future__ import annotations
from dataclasses import dataclass, field
import json
import logging
from typing import TYPE_CHECKING
import math

from .file_utils import (
    load_dict_from_json_file,
    load_dict_from_yaml_file,
    save_dict_to_file
)

from . import __version__
from .exceptions import ModelException

if TYPE_CHECKING:
    from typing import Any, Optional
    from .language import (
        LanguageGraph,
        LanguageGraphAsset,
    )

logger = logging.getLogger(__name__)

@dataclass
class AttackerAttachment:
    """Used to attach attackers to attack step entry points of assets"""

    # Identifier of the attacker; None by default.
    id: Optional[int] = None
    # Name of the attacker; None by default.
    name: Optional[str] = None
    # One (asset, attack step names) tuple per asset that has entry points.
    entry_points: list[tuple[ModelAsset, list[str]]] = field(
        default_factory=list
    )

    def get_entry_point_tuple(
        self,
        asset: ModelAsset
    ) -> Optional[tuple[ModelAsset, list[str]]]:
        """Look up the entry point tuple registered for an asset.

        Arguments:
        asset   - the asset whose entry point tuple is wanted

        Return:
        The first tuple in `self.entry_points` whose asset compares equal
        to `asset`, i.e. the asset together with its list of attack step
        names. None if this attacker attachment has no tuple for the
        asset.
        """
        for candidate in self.entry_points:
            if candidate[0] == asset:
                return candidate
        return None

    def add_entry_point(self, asset: ModelAsset, attackstep_name: str):
        """Register an attack step of an asset as an entry point.

        `self.entry_points` holds tuples whose first element is an asset
        and whose second element is the list of attack step names that
        are entry points on that asset. A new tuple is created the first
        time an asset is seen; afterwards step names are appended to the
        existing list. Adding a step name that is already present only
        logs an info message.

        Arguments:
        asset           - the asset to add the entry point to
        attackstep_name - the name of the attack step to add as an entry
                          point
        """
        logger.debug(
            f'Add entry point "{attackstep_name}" on asset "{asset.name}" '
            f'to AttackerAttachment "{self.name}".'
        )

        existing = self.get_entry_point_tuple(asset)
        if not existing:
            # First entry point on this asset: create its tuple.
            self.entry_points.append((asset, [attackstep_name]))
            return

        step_names = existing[1]
        if attackstep_name in step_names:
            logger.info(
                f'Entry point "{attackstep_name}" on asset "{asset.name}"'
                f' already existed for AttackerAttachment "{self.name}".'
            )
            return

        step_names.append(attackstep_name)

    def remove_entry_point(self, asset: ModelAsset, attackstep_name: str):
        """Unregister an attack step of an asset as an entry point.

        If the asset has no tuple, or the step name is not in the asset's
        list, a warning is logged and nothing is removed. When the
        asset's list of step names ends up empty, the asset's tuple is
        dropped from `self.entry_points` as well.

        Arguments:
        asset           - the asset to remove the entry point from
        attackstep_name - the name of the attack step to remove
        """
        logger.debug(
            f'Remove entry point "{attackstep_name}" on asset "{asset.name}" '
            f'from AttackerAttachment "{self.name}".'
        )

        existing = self.get_entry_point_tuple(asset)
        if not existing:
            logger.warning(
                f'Failed to find entry points on asset "{asset.name}" '
                f'for AttackerAttachment "{self.name}". Nothing to remove.'
            )
            return

        step_names = existing[1]
        if attackstep_name in step_names:
            # The step is registered for this asset, so drop it.
            step_names.remove(attackstep_name)
        else:
            logger.warning(
                f'Failed to find entry point "{attackstep_name}" on '
                f'asset "{asset.name}" for AttackerAttachment '
                f'"{self.name}". Nothing to remove.'
            )

        # Do not keep a tuple around for an asset without entry points.
        if not step_names:
            self.entry_points.remove(existing)
class Model():
    """An implementation of a MAL language model containing assets

    The model keeps its assets indexed both by id (`assets`) and by name
    (`_name_to_asset`), a list of attacker attachments (`attackers`) and
    the language graph that asset types are looked up in. Asset ids and
    attacker ids are drawn from the same `next_id` counter.
    """

    # Default value of the id counter. add_asset/add_attacker rebind
    # `next_id` on the instance, so the counter is tracked per model.
    next_id: int = 0

    def __repr__(self) -> str:
        return f'Model(name: "{self.name}", language: {self.lang_graph})'

    def __init__(
        self,
        name: str,
        lang_graph: LanguageGraph,
        mt_version: str = __version__
    ):
        """Create an empty model.

        Arguments:
        name        - the name of the model
        lang_graph  - the language graph asset types are taken from
        mt_version  - the MAL-Toolbox version to record for this model,
                      defaults to the running toolbox version
        """
        self.name = name
        self.assets: dict[int, ModelAsset] = {}
        self._name_to_asset: dict[str, ModelAsset] = {}  # optimization
        self.attackers: list[AttackerAttachment] = []
        self.lang_graph = lang_graph
        self.maltoolbox_version: str = mt_version

    def add_asset(
        self,
        asset_type: str,
        name: Optional[str] = None,
        asset_id: Optional[int] = None,
        defenses: Optional[dict[str, float]] = None,
        extras: Optional[dict] = None,
        allow_duplicate_names: bool = True
    ) -> ModelAsset:
        """
        Create an asset based on the provided parameters and add it to the
        model.

        Arguments:
        asset_type              - string containing the asset type name
        name                    - string containing the asset name. If not
                                  provided the concatenated asset type and
                                  id will be used as a name.
        asset_id                - id to assign to this asset, usually from
                                  an instance model file. If not provided
                                  (None) the next available id is used.
                                  An explicit id of 0 is honoured.
        defenses                - dictionary of defense values
        extras                  - dictionary of extras
        allow_duplicate_names   - allow duplicate names to be used. If
                                  allowed and a duplicate is encountered
                                  the name will be appended with the id.

        Return:
        The newly created asset.

        Raises:
        ValueError  - the id is already in use, or the name is a
                      duplicate and duplicates are not allowed
        """
        # 0 is a valid id, so only fall back to the counter when no id
        # was given at all (`asset_id or self.next_id` would discard 0).
        if asset_id is None:
            asset_id = self.next_id
        if asset_id in self.assets:
            raise ValueError(f'Asset index {asset_id} already in use.')
        self.next_id = max(asset_id + 1, self.next_id)

        if not name:
            name = asset_type + ':' + str(asset_id)
        elif name in self._name_to_asset:
            if not allow_duplicate_names:
                raise ValueError(
                    f'Asset name {name} is a duplicate'
                    ' and we do not allow duplicates.'
                )
            # Disambiguate the duplicate by appending the id.
            name = name + ':' + str(asset_id)

        lg_asset = self.lang_graph.assets[asset_type]

        asset = ModelAsset(
            name=name,
            asset_id=asset_id,
            lg_asset=lg_asset,
            defenses=defenses,
            extras=extras
        )

        logger.debug(
            'Add "%s"(%d) to model "%s".', name, asset_id, self.name
        )
        self.assets[asset_id] = asset
        self._name_to_asset[name] = asset

        return asset

    def remove_attacker(self, attacker: AttackerAttachment) -> None:
        """Remove attacker

        Arguments:
        attacker    - the attacker to remove from `self.attackers`

        Raises:
        ValueError  - the attacker is not in the list (from list.remove)
        """
        self.attackers.remove(attacker)

    def remove_asset(self, asset: ModelAsset) -> None:
        """Remove an asset from the model.

        All associations of the asset and all attacker entry points on
        it are removed as well.

        Arguments:
        asset     - the asset to remove

        Raises:
        LookupError - the id of the asset is not present in the model
        """
        logger.debug(
            'Remove "%s"(%d) from model "%s".',
            asset.name, asset.id, self.name
        )
        if asset.id not in self.assets:
            raise LookupError(
                f'Asset "{asset.name}"({asset.id}) is not part'
                f' of model"{self.name}".'
            )

        # First remove all of the associated assets.
        # Removing associations mutates the asset's dictionary, so iterate
        # over a snapshot of the field names and re-read the current state
        # for each one: if the asset is associated with itself, handling
        # one field also empties (and deletes) the opposite field.
        for fieldname in list(asset.associated_assets):
            linked = asset.associated_assets.get(fieldname)
            if linked:
                # Pass a copy so the callee never iterates a set that it
                # is shrinking at the same time.
                asset.remove_associated_assets(fieldname, set(linked))

        # Also remove all of the entry points
        for attacker in self.attackers:
            entry_point_tuple = attacker.get_entry_point_tuple(asset)
            if entry_point_tuple:
                attacker.entry_points.remove(entry_point_tuple)

        del self.assets[asset.id]
        del self._name_to_asset[asset.name]

    def add_attacker(
        self,
        attacker: AttackerAttachment,
        attacker_id: Optional[int] = None
    ) -> None:
        """Add an attacker to the model.

        The attacker's id is overwritten with `attacker_id`, or with the
        next available id when none is given. An attacker without a name
        is named 'Attacker:<id>'.

        Arguments:
        attacker        - the attacker to add
        attacker_id     - optional id for the attacker
        """
        if attacker_id is not None:
            attacker.id = attacker_id
        else:
            attacker.id = self.next_id
        self.next_id = max(attacker.id + 1, self.next_id)

        if not getattr(attacker, 'name', None):
            attacker.name = 'Attacker:' + str(attacker.id)
        self.attackers.append(attacker)

    def get_asset_by_id(
        self, asset_id: int
    ) -> Optional[ModelAsset]:
        """
        Find an asset in the model based on its id.

        Arguments:
        asset_id        - the id of the asset we are looking for

        Return:
        An asset matching the id if it exists in the model, otherwise
        None.
        """
        logger.debug(
            'Get asset with id %d from model "%s".',
            asset_id, self.name
        )
        return self.assets.get(asset_id, None)

    def get_asset_by_name(
        self, asset_name: str
    ) -> Optional[ModelAsset]:
        """
        Find an asset in the model based on its name.

        Arguments:
        asset_name        - the name of the asset we are looking for

        Return:
        An asset matching the name if it exists in the model, otherwise
        None.
        """
        logger.debug(
            'Get asset with name "%s" from model "%s".',
            asset_name, self.name
        )
        return self._name_to_asset.get(asset_name, None)

    def get_attacker_by_id(
        self, attacker_id: int
    ) -> Optional[AttackerAttachment]:
        """
        Find an attacker in the model based on its id.

        Arguments:
        attacker_id     - the id of the attacker we are looking for

        Return:
        The first attacker matching the id if one exists in the model,
        otherwise None.
        """
        logger.debug(
            'Get attacker with id %d from model "%s".',
            attacker_id, self.name
        )
        return next(
            (attacker for attacker in self.attackers
             if attacker.id == attacker_id),
            None
        )

    def attacker_to_dict(
        self, attacker: AttackerAttachment
    ) -> tuple[Optional[int], dict]:
        """Get dictionary representation of the attacker.

        Arguments:
        attacker    - attacker to get dictionary representation of

        Return:
        A tuple of the attacker id and a dictionary holding the attacker
        name and, per asset name, the asset id and its entry point attack
        step names.
        """
        logger.debug('Translating %s to dictionary.', attacker.name)
        attacker_dict: dict[str, Any] = {
            'name': attacker.name,
            'entry_points': {},
        }
        for (asset, attack_steps) in attacker.entry_points:
            attacker_dict['entry_points'][asset.name] = {
                'asset_id': asset.id,
                'attack_steps': attack_steps
            }
        return (attacker.id, attacker_dict)

    def _to_dict(self) -> dict:
        """Get dictionary representation of the model.

        Return:
        A dictionary with the keys 'metadata', 'assets' (keyed by asset
        id) and 'attackers' (keyed by attacker id).
        """
        logger.debug('Translating model to dict.')
        contents: dict[str, Any] = {
            'metadata': {
                'name': self.name,
                'langVersion': self.lang_graph.metadata['version'],
                'langID': self.lang_graph.metadata['id'],
                'malVersion': '0.1.0-SNAPSHOT',
                'MAL-Toolbox Version': __version__,
                'info': 'Created by the mal-toolbox model python module.'
            },
            'assets': {},
            'attackers': {}
        }

        logger.debug('Translating assets to dictionary.')
        for asset in self.assets.values():
            contents['assets'].update(asset._to_dict())

        logger.debug('Translating attackers to dictionary.')
        for attacker in self.attackers:
            (attacker_id, attacker_dict) = self.attacker_to_dict(attacker)
            contents['attackers'][attacker_id] = attacker_dict
        return contents

    def save_to_file(self, filename: str) -> None:
        """Save to json/yml depending on extension

        Arguments:
        filename    - path of the file to write the model to
        """
        logger.debug('Save instance model to file "%s".', filename)
        return save_dict_to_file(filename, self._to_dict())

    @classmethod
    def _from_dict(
        cls,
        serialized_object: dict,
        lang_graph: LanguageGraph,
    ) -> Model:
        """Create a model from dict representation

        Arguments:
        serialized_object   - Model in dict format, with a 'metadata'
                              and an 'assets' entry and optionally an
                              'attackers' entry
        lang_graph          - the language graph the asset types are
                              looked up in

        Return:
        The reconstructed model.

        Raises:
        LookupError - an attacker entry point refers to an asset id that
                      is not part of the model
        """
        metadata = serialized_object['metadata']
        # `_to_dict` writes 'MAL-Toolbox Version', while this loader
        # used to look only for 'MAL Toolbox Version'. Accept both so
        # the version stored in a saved model is actually picked up.
        if 'MAL Toolbox Version' in metadata:
            maltoolbox_version = metadata['MAL Toolbox Version']
        else:
            maltoolbox_version = metadata.get(
                'MAL-Toolbox Version', __version__
            )

        model = Model(
            metadata['name'],
            lang_graph,
            mt_version=maltoolbox_version
        )

        # Reconstruct the assets
        for asset_id, asset_dict in serialized_object['assets'].items():

            if logger.isEnabledFor(logging.DEBUG):
                # Avoid running json.dumps when not in debug
                logger.debug(
                    "Loading asset:\n%s", json.dumps(asset_dict, indent=2)
                )

            # Allow defining an asset via type only.
            if not isinstance(asset_dict, dict):
                asset_dict = {
                    'type': asset_dict,
                    'name': f"{asset_dict}:{asset_id}"
                }

            model.add_asset(
                asset_type=asset_dict['type'],
                name=asset_dict.get('name'),
                defenses={
                    defense: float(value)
                    for defense, value
                    in asset_dict.get('defenses', {}).items()
                },
                extras=asset_dict.get('extras', {}),
                asset_id=int(asset_id)
            )

        # Reconstruct the association links
        for asset_id, asset_dict in serialized_object['assets'].items():
            if not isinstance(asset_dict, dict):
                # A type-only asset definition carries no associations.
                continue
            asset = model.assets[int(asset_id)]
            assoc_assets_dict = asset_dict.get('associated_assets', {})
            for fieldname, assoc_assets in assoc_assets_dict.items():
                asset.add_associated_assets(
                    fieldname,
                    {model.assets[int(assoc_asset_id)]
                     for assoc_asset_id in assoc_assets}
                )

        # Reconstruct the attackers
        attackers_info = serialized_object.get('attackers') or {}
        for attacker_id, attacker_info in attackers_info.items():
            attacker = AttackerAttachment(name=attacker_info['name'])
            for asset_name, entry_points_dict in \
                    attacker_info['entry_points'].items():
                # Convert once: the id may be stored as a string, which
                # the error message below must cope with as well.
                target_id = int(entry_points_dict['asset_id'])
                target_asset = model.get_asset_by_id(target_id)
                if target_asset is None:
                    raise LookupError(
                        'Asset "%s"(%d) is not part of model "%s".'
                        % (asset_name, target_id, model.name)
                    )
                attacker.entry_points.append(
                    (target_asset, entry_points_dict['attack_steps'])
                )
            model.add_attacker(attacker, attacker_id=int(attacker_id))
        return model

    @classmethod
    def load_from_file(
        cls,
        filename: str,
        lang_graph: LanguageGraph,
    ) -> Model:
        """Create from json or yaml file depending on file extension

        Arguments:
        filename    - path of a .yml, .yaml or .json model file
        lang_graph  - the language graph the asset types are looked up
                      in

        Return:
        The loaded model.

        Raises:
        ValueError      - the file extension is not json/yml/yaml
        ModelException  - the file content could not be turned into a
                          model; the original error is chained as cause
        """
        logger.debug('Load instance model from file "%s".', filename)
        if filename.endswith(('.yml', '.yaml')):
            serialized_model = load_dict_from_yaml_file(filename)
        elif filename.endswith('.json'):
            serialized_model = load_dict_from_json_file(filename)
        else:
            raise ValueError('Unknown file extension, expected json/yml/yaml')

        try:
            return cls._from_dict(serialized_model, lang_graph)
        except Exception as e:
            raise ModelException(
                "Could not load model. It might be of an older version. "
                "Try to upgrade it with 'maltoolbox upgrade-model'"
            ) from e
class ModelAsset:
    """An asset instance of a model.

    An asset has a name, an id, a language graph asset describing its
    type, defense values, optional extras and, per association field
    name, the set of assets it is associated with.
    """

    def __init__(
        self,
        name: str,
        asset_id: int,
        lg_asset: LanguageGraphAsset,
        defenses: Optional[dict[str, float]] = None,
        extras: Optional[dict] = None
    ):
        """Create an asset.

        Arguments:
        name        - the name of the asset
        asset_id    - the id of the asset
        lg_asset    - the language graph asset giving the asset its type
        defenses    - dictionary of defense values. Defenses of the
                      asset type missing from it get their default
                      value. The dictionary passed in is not modified.
        extras      - dictionary of extras
        """
        self.name: str = name
        self._id: int = asset_id
        self.lg_asset: LanguageGraphAsset = lg_asset
        self.type = self.lg_asset.name
        # Work on a copy: the defaults filled in below must not leak
        # into the dictionary owned by the caller.
        self.defenses: dict[str, float] = dict(defenses) if defenses else {}
        self.extras: dict = extras or {}
        self._associated_assets: dict[str, set[ModelAsset]] = {}
        self.attack_step_nodes: list = []

        for step in self.lg_asset.attack_steps.values():
            if step.type == 'defense' and step.name not in self.defenses:
                self.defenses[step.name] = self._default_defense_value(step)

    @staticmethod
    def _default_defense_value(lg_step) -> float:
        """Return 1.0 if the step's ttc is named 'Enabled', else 0.0."""
        if lg_step.ttc and lg_step.ttc['name'] == 'Enabled':
            return 1.0
        return 0.0

    def _to_dict(self):
        """Get dictionary representation of the asset.

        Return:
        A dictionary mapping the asset id to a dictionary with the
        asset's name, type, non-default defenses (omitted if there are
        none), associated assets and extras (omitted if empty).
        """
        logger.debug(
            'Translating "%s"(%d) to dictionary.', self.name, self.id
        )

        asset_dict: dict[str, Any] = {
            'name': self.name,
            'type': self.type,
            'defenses': {},
            'associated_assets': {}
        }

        # Only add non-default values for defenses to improve legibility of
        # the model format
        for defense, defense_value in self.defenses.items():
            lg_step = self.lg_asset.attack_steps[defense]
            if defense_value != self._default_defense_value(lg_step):
                asset_dict['defenses'][defense] = defense_value

        for fieldname, assets in self.associated_assets.items():
            asset_dict['associated_assets'][fieldname] = {
                asset.id: asset.name for asset in assets
            }

        if not asset_dict['defenses']:
            # Do not include an empty defenses dictionary
            del asset_dict['defenses']

        if self.extras:
            # Add optional metadata to dict
            asset_dict['extras'] = self.extras

        return {self.id: asset_dict}

    def __repr__(self):
        return (f'ModelAsset(name: "{self.name}", id: {self.id}, '
                f'type: {self.type})')

    def validate_associated_assets(
        self, fieldname: str, assets_to_add: set[ModelAsset]
    ):
        """
        Validate an association we want to add (through `fieldname`)
        is valid with the assets given in param `assets_to_add`:
        - fieldname is valid for the asset type of this ModelAsset
        - type of `assets_to_add` is valid for the association
        - no more assets than 'field.maximum' are added to the field

        Arguments:
        fieldname       - the association field name on this asset
        assets_to_add   - the assets that are to be added to the field

        Raises:
            LookupError - fieldname can not be found for this ModelAsset
            ValueError - there will be too many assets in the field
                         if we add this association
            TypeError - if the asset type of `assets_to_add` is not valid
        """

        # Validate that the field name is allowed for this asset type
        if fieldname not in self.lg_asset.associations:
            accepted_fieldnames = list(self.lg_asset.associations.keys())
            raise LookupError(
                f"Fieldname '{fieldname}' is not an accepted association "
                f"fieldname from asset type {self.lg_asset.name}. "
                f"Did you mean one of {accepted_fieldnames}?"
            )

        lg_assoc = self.lg_asset.associations[fieldname]
        assoc_field = lg_assoc.get_field(fieldname)

        # Validate that the asset to add association to is of correct type
        for asset_to_add in assets_to_add:
            if not asset_to_add.lg_asset.is_subasset_of(assoc_field.asset):
                raise TypeError(
                    f"Asset '{asset_to_add.name}' of type "
                    f"'{asset_to_add.type}' can not be added to association "
                    f"'{self.name}.{fieldname}'. Expected type of "
                    f"'{fieldname}' is {assoc_field.asset.name}."
                )

        # Validate that there will not be too many assets in field
        assets_in_field_before = self.associated_assets.get(fieldname, set())
        assets_in_field_after = assets_in_field_before | set(assets_to_add)
        # A falsy maximum (None or 0) is treated as "no upper bound".
        max_assets_in_field = assoc_field.maximum or math.inf

        if len(assets_in_field_after) > max_assets_in_field:
            raise ValueError(
                f"You can have maximum {assoc_field.maximum} "
                f"assets for association field {fieldname}"
            )

    def add_associated_assets(self, fieldname: str, assets: set[ModelAsset]):
        """
        Add the assets provided as a parameter to the set of associated
        assets dictionary entry corresponding to the given fieldname.

        The association is validated from both sides before anything is
        changed, and it is recorded on both sides: on this asset under
        `fieldname` and on each of `assets` under the opposite field
        name.

        Arguments:
        fieldname   - the association field name on this asset
        assets      - the assets to associate with this asset

        Raises:
        KeyError    - fieldname is not an association of this asset's
                      type (raised by the lookup before validation runs;
                      KeyError is a LookupError)
        See validate_associated_assets for the validation errors.
        """

        lg_assoc = self.lg_asset.associations[fieldname]
        other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)

        # Validation from both sides
        self.validate_associated_assets(fieldname, assets)
        for asset in assets:
            asset.validate_associated_assets(other_fieldname, {self})

        # Add the associated assets to this asset's dictionary
        self._associated_assets.setdefault(fieldname, set()).update(assets)

        # Add this asset to the associated assets' corresponding dictionaries
        for asset in assets:
            asset._associated_assets.setdefault(
                other_fieldname, set()
            ).add(self)

    def remove_associated_assets(
        self, fieldname: str, assets: set[ModelAsset]
    ):
        """
        Remove the assets provided as a parameter from the set of
        associated assets dictionary entry corresponding to the fieldname
        parameter.

        This asset is also removed from the opposite field of each of
        `assets`. Field entries that end up empty are deleted. If this
        asset has no entry for `fieldname` any more (for example because
        removing a self-association through the opposite field already
        emptied it) there is nothing left to remove on this side.

        Arguments:
        fieldname   - the association field name on this asset
        assets      - the assets to disassociate from this asset

        Raises:
        KeyError    - fieldname is not an association of this asset's
                      type, or one of `assets` is not associated with
                      this asset
        """
        lg_assoc = self.lg_asset.associations[fieldname]
        other_fieldname = lg_assoc.get_opposite_fieldname(fieldname)

        # Snapshot the input: callers may pass one of the live sets that
        # are shrunk below.
        targets = set(assets)

        # Remove this asset from its associated assets' dictionaries
        for asset in targets:
            reverse_links = asset._associated_assets[other_fieldname]
            reverse_links.remove(self)
            if not reverse_links:
                del asset._associated_assets[other_fieldname]

        # Remove associated assets from this asset
        own_links = self._associated_assets.get(fieldname)
        if own_links is None:
            return
        own_links -= targets
        if not own_links:
            del self._associated_assets[fieldname]

    @property
    def associated_assets(self):
        """The associated assets of this asset, per association field."""
        return self._associated_assets

    @property
    def id(self):
        """The id of this asset."""
        return self._id