from __future__ import annotations
import logging
import re
import typing as ty
from pathlib import Path
import shutil
import attrs
import attrs.filters
from attrs.converters import default_if_none
from pydra.utils.hash import hash_single, bytes_repr_mapping_contents
from fileformats.text import Plain as PlainText
from arcana.core.exceptions import (
ArcanaDataMatchError,
ArcanaLicenseNotFoundError,
ArcanaNameError,
ArcanaUsageError,
ArcanaWrongDataSpaceError,
)
from ..column import DataColumn, DataSink, DataSource
from .. import store as datastore
from ..tree import DataTree
from ..space import DataSpace
from .metadata import DatasetMetadata, metadata_converter
if ty.TYPE_CHECKING: # pragma: no cover
from arcana.core.deploy.image.components import License
from arcana.core.data.entry import DataEntry
from arcana.core.analysis.base import Analysis
from arcana.core.analysis.pipeline import Pipeline
logger = logging.getLogger("arcana")
[docs]
@attrs.define(kw_only=True)
class Dataset:
"""
A representation of a "dataset", the complete collection of data
(file-sets and fields) to be used in an analysis.
Parameters
----------
id : str
The dataset id/path that uniquely identifies the dataset within the
store it is stored (e.g. FS directory path or project ID)
store : Repository
The store the dataset is stored into. Can be the local file
system by providing a MockRemote repo.
space: DataSpace
The space of the dataset. See https://arcana.readthedocs.io/en/latest/data_model.html#spaces)
for a description
id_patterns : dict[str, str]
Patterns for inferring IDs of rows not explicitly present in the hierarchy of
the data tree. See ``DataStore.infer_ids()`` for syntax
hierarchy : list[str]
The data frequencies that are explicitly present in the data tree.
For example, if a MockRemote dataset (i.e. directory) has
two layer hierarchy of sub-directories, the first layer of
sub-directories labelled by unique subject ID, and the second directory
layer labelled by study time-point then the hierarchy would be
['subject', 'timepoint']
Alternatively, in some stores (e.g. XNAT) the second layer in the
hierarchy may be named with session ID that is unique across the project,
in which case the layer dimensions would instead be
['subject', 'session']
In such cases, if there are multiple timepoints, the timepoint ID of the
session will need to be extracted using the `id_patterns` argument.
Alternatively, the hierarchy could be organised such that the tree
first splits on longitudinal time-points, then a second directory layer
labelled by member ID, with the final layer containing sessions of
matched members labelled by their groups (e.g. test & control):
['timepoint', 'member', 'group']
Note that the combination of layers in the hierarchy must span the
space defined in the DataSpace enum, i.e. the "bitwise or" of the
layer values of the hierarchy must be 1 across all bits
(e.g. 'session': 0b111).
metadata : dict or DatasetMetadata
Generic metadata associated with the dataset, e.g. authors, funding sources, etc...
include : list[tuple[DataSpace, str or ty.List[str]]]
The IDs to be included in the dataset per row_frequency. E.g. can be
used to limit the subject IDs in a project to the sub-set that passed
QC. If a row_frequency is omitted or its value is None, then all available
will be used
exclude : list[tuple[DataSpace, str or ty.List[str]]]
The IDs to be excluded in the dataset per row_frequency. E.g. can be
used to exclude specific subjects that failed QC. If a row_frequency is
omitted or its value is None, then all available will be used
name : str
The name of the dataset as saved in the store under
columns : list[tuple[str, DataSource or DataSink]
The sources and sinks to be initially added to the dataset (columns are
explicitly added when workflows are applied to the dataset).
pipelines : dict[str, pydra.Workflow]
Pipelines that have been applied to the dataset to generate sink
access_args: ty.Dict[str, Any]
Repository specific args used to control the way the dataset is accessed
"""
LICENSES_PATH = (
"LICENSES" # The resource that project-specifc licenses are expected
)
id: str = attrs.field(converter=str, metadata={"asdict": False})
store: datastore.DataStore = attrs.field()
space: ty.Type[DataSpace] = attrs.field()
id_patterns: ty.Dict[str, str] = attrs.field(
factory=dict, converter=default_if_none(factory=dict)
)
hierarchy: ty.List[str] = attrs.field(converter=list)
metadata: DatasetMetadata = attrs.field(
factory=DatasetMetadata,
converter=metadata_converter,
repr=False,
)
include: ty.Dict[str, ty.Union[ty.List[str], str]] = attrs.field(
factory=dict, converter=default_if_none(factory=dict), repr=False
)
exclude: ty.Dict[str, ty.Union[ty.List[str], str]] = attrs.field(
factory=dict, converter=default_if_none(factory=dict), repr=False
)
name: str = attrs.field(default="")
columns: ty.Dict[str, DataColumn] = attrs.field(
factory=dict, converter=default_if_none(factory=dict), repr=False
)
pipelines: ty.Dict[str, Pipeline] = attrs.field(
factory=dict, converter=default_if_none(factory=dict), repr=False
)
analyses: ty.Dict[str, Analysis] = attrs.field(
factory=dict, converter=default_if_none(factory=dict), repr=False
)
tree: DataTree = attrs.field(factory=DataTree, init=False, repr=False, eq=False)
def __attrs_post_init__(self):
self.tree.dataset = self
# Set reference to pipeline in columns and pipelines
for column in self.columns.values():
column.dataset = self
for pipeline in self.pipelines.values():
pipeline.dataset = self
@name.validator
def name_validator(self, _, name: str):
if name and not name.isidentifier():
raise ArcanaUsageError(
f"Name provided to dataset, '{name}' should be a valid Python identifier, "
"i.e. contain only numbers, letters and underscores and not start with a "
"number"
)
if name == self.store.EMPTY_DATASET_NAME:
raise ArcanaUsageError(
f"'{self.store.EMPTY_DATASET_NAME}' is a reserved name for datasets as it is used to "
"in place of the empty dataset name in situations where '' can't be used"
)
@columns.validator
def columns_validator(self, _, columns):
wrong_freq = [
m for m in columns.values() if not isinstance(m.row_frequency, self.space)
]
if wrong_freq:
raise ArcanaUsageError(
f"Data hierarchy of {wrong_freq} column specs do(es) not match "
f"that of dataset {self.space}"
)
@include.validator
def include_validator(self, _, include: ty.Dict[str, ty.Union[str, ty.List[str]]]):
valid = set(str(f) for f in self.space)
freqs = set(include)
unrecognised = freqs - valid
if unrecognised:
raise ArcanaUsageError(
f"Unrecognised frequencies in 'include' dictionary provided to {self}: "
+ ", ".join(unrecognised)
)
self._validate_criteria(include, "inclusion")
@exclude.validator
def exclude_validator(self, _, exclude: ty.Dict[str, ty.Union[str, ty.List[str]]]):
valid = set(self.hierarchy)
freqs = set(exclude)
unrecognised = freqs - valid
if unrecognised:
raise ArcanaUsageError(
f"Unrecognised frequencies in 'exclude' dictionary provided to {self}, "
"only frequencies present in the dataset hierarchy are allowed: "
+ ", ".join(unrecognised)
)
self._validate_criteria(exclude, "exclusion")
def _validate_criteria(self, criteria, type_):
for freq, criterion in criteria.items():
try:
re.compile(criterion)
except Exception:
if not isinstance(criterion, list) or any(
not isinstance(x, str) for x in criterion
):
raise ArcanaUsageError(
f"Unrecognised {type_} criterion for '{freq}' provided to {self}, "
f"{criterion}, should either be a list of ID strings or a valid "
"regular expression"
)
@hierarchy.validator
def hierarchy_validator(self, _, hierarchy):
not_valid = [f for f in hierarchy if str(f) not in self.space.__members__]
if not_valid:
raise ArcanaWrongDataSpaceError(
f"hierarchy items {not_valid} are not part of the {self.space} data space"
)
# Check that all data frequencies are "covered" by the hierarchy and
# each subsequent
covered = self.space(0)
for i, layer_str in enumerate(hierarchy):
layer = self.space[str(layer_str)]
diff = (layer ^ covered) & layer
if not diff:
raise ArcanaUsageError(
f"{layer} does not add any additional basis layers to "
f"previous layers {hierarchy[i:]}"
)
covered |= layer
if covered != max(self.space):
raise ArcanaUsageError(
"The data hierarchy ['"
+ "', '".join(hierarchy)
+ "'] does not cover the following basis frequencies ['"
+ "', '".join(str(m) for m in (covered ^ max(self.space)).span())
+ f"'] the '{self.space.__module__}.{self.space.__name__}' data space"
)
# if missing_axes:
# raise ArcanaDataTreeConstructionError(
# "Leaf node at %s is missing explicit IDs for the following axes, %s"
# ", they will be set to None, noting that an error will be raised if there "
# " multiple nodes for this session. In that case, set 'id-patterns' on the "
# "dataset to extract the missing axis IDs from composite IDs or row "
# "metadata",
# tree_path,
# missing_axes,
# )
# for m in missing_axes:
# ids[m] = None
@id_patterns.validator
def id_patterns_validator(self, _, id_patterns):
non_valid_keys = [f for f in id_patterns if f not in self.space.__members__]
if non_valid_keys:
raise ArcanaWrongDataSpaceError(
f"Keys for the id_patterns dictionary {non_valid_keys} are not part "
f"of the {self.space} data space"
)
for key, expr in id_patterns.items():
groups = list(re.compile(expr).groupindex)
non_valid_groups = [f for f in groups if f not in self.space.__members__]
if non_valid_groups:
raise ArcanaWrongDataSpaceError(
f"Groups in the {key} id_patterns expression {non_valid_groups} "
f"are not part of the {self.space} data space"
)
def save(self, name=""):
self.store.save_dataset(self, name=name)
@classmethod
def load(
cls,
id: str,
store: datastore.DataStore = None,
name: ty.Optional[str] = "",
**kwargs,
):
"""Loads a dataset from an store/ID/name string, as used in the CLI
Parameters
----------
id: str
either the ID of a dataset if `store` keyword arg is provided or a
"dataset ID string" in the format <store-nickname>//<dataset-id>[@<dataset-name>]
store: DataStore, optional
the store to load the dataset. If not provided the provided ID
is interpreted as an ID string
name: str, optional
the name of the dataset within the project/directory
(e.g. 'test', 'training'). Used to specify a subset of data rows
to work with, within a greater project
**kwargs
keyword arguments parsed to the data store load
Returns
-------
Dataset
the loaded dataset"""
if store is None:
store_name, id, parsed_name = cls.parse_id_str(id)
store = datastore.DataStore.load(store_name, **kwargs)
if not name and parsed_name:
name = parsed_name
return store.load_dataset(id, name=name)
@property
def root_freq(self):
return self.space(0)
@property
def root_dir(self):
return Path(self.id)
@property
def leaf_freq(self):
return max(self.space)
@property
def prov(self):
return {
"id": self.id,
"store": self.store.prov,
"ids": {str(freq): tuple(ids) for freq, ids in self.rows.items()},
}
@property
def root(self):
"""Lazily loads the data tree from the store on demand and return root
Returns
-------
DataRow
The root row of the data tree
"""
# Build the tree cache and return the tree root. Note that if there is a
# "with <this-dataset>.tree" statement further up the call stack then the
# cache won't be broken down until the highest cache statement exits
with self.tree:
return self.tree.root
@property
def locator(self):
if self.store.name is None:
raise Exception(
f"Must save store {self.store} first before accessing locator for "
f"{self}"
)
locator = f"{self.store.name}//{self.id}"
if self.name:
locator += f"@{self.name}"
return locator
[docs]
def add_source(
self,
name: str,
datatype: type,
path: ty.Optional[str] = None,
row_frequency: ty.Optional[str] = None,
overwrite: bool = False,
**kwargs,
) -> DataSource:
"""Specify a data source in the dataset, which can then be referenced
when connecting workflow inputs.
Parameters
----------
name : str
The name used to reference the dataset "column" for the
source
datatype : type
The file-format (for file-sets) or datatype (for fields)
that the source will be stored in within the dataset
path : str, default `name`
The location of the source within the dataset
row_frequency : DataSpace, default self.leaf_freq
The row_frequency of the source within the dataset
overwrite : bool
Whether to overwrite existing columns
**kwargs : ty.Dict[str, Any]
Additional kwargs to pass to DataSource.__init__
"""
row_frequency = self.parse_frequency(row_frequency)
if path is None:
path = name
source = DataSource(
name=name,
datatype=datatype,
path=path,
row_frequency=row_frequency,
dataset=self,
**kwargs,
)
self._add_column(name, source, overwrite)
return source
[docs]
def add_sink(
self,
name: str,
datatype: type,
row_frequency: ty.Optional[str] = None,
overwrite: bool = False,
**kwargs,
) -> DataSink:
"""Specify a data source in the dataset, which can then be referenced
when connecting workflow inputs.
Parameters
----------
name : str
The name used to reference the dataset "column" for the
sink
datatype : type
The file-format (for file-sets) or datatype (for fields)
that the sink will be stored in within the dataset
path : str, optional
Specify a particular for the sink within the dataset, defaults to the column
name within the dataset derivatives directory of the store
row_frequency : str, optional
The row_frequency of the sink within the dataset, by default the leaf
frequency of the data tree
overwrite : bool
Whether to overwrite an existing sink
"""
row_frequency = self.parse_frequency(row_frequency)
sink = DataSink(
name=name,
datatype=datatype,
row_frequency=row_frequency,
dataset=self,
**kwargs,
)
self._add_column(name, sink, overwrite)
return sink
def _add_column(self, name: str, spec, overwrite):
if name in self.columns:
if overwrite:
logger.info(
f"Overwriting {self.columns[name]} with {spec} in " f"{self}"
)
else:
raise ArcanaNameError(
name,
f"Name clash attempting to add {spec} to {self} "
f"with {self.columns[name]}. Use 'overwrite' option "
"if this is desired",
)
self.columns[name] = spec
def row(self, frequency=None, id=attrs.NOTHING, **id_kwargs):
"""Returns the row associated with the given frequency and ids dict
Parameters
----------
frequency : DataSpace or str
The frequency of the row
id : str or Tuple[str], optional
The ID of the row to
**id_kwargs : Dict[str, str]
Alternatively to providing `id`, ID corresponding to the row to
return passed as kwargs
Returns
-------
DataRow
The selected data row
Raises
------
ArcanaUsageError
Raised when attempting to use IDs with the frequency associated
with the root row
ArcanaNameError
If there is no row corresponding to the given ids
"""
with self.tree:
# Parse str to frequency enums
if not frequency:
if id not in (None, attrs.NOTHING):
raise ArcanaUsageError(f"Root rows don't have any IDs ({id})")
return self.root
frequency = self.parse_frequency(frequency)
if id is not attrs.NOTHING:
if id_kwargs:
raise ArcanaUsageError(
f"ID ({id}) and id_kwargs ({id_kwargs}) cannot be both "
f"provided to `row` method of {self}"
)
try:
return self.root.children[frequency][id]
except KeyError as e:
if isinstance(id, tuple) and len(id) == self.space.ndim:
# Expand ID tuple to see if it is an expansion of the ID axes
# instead of a direct label for the row
id_kwargs = {a: i for a, i in zip(self.space.axes(), id)}
else:
raise ArcanaNameError(
id,
f"{id} not present in data tree "
f"({list(self.row_ids(frequency))})",
) from e
elif not id_kwargs:
raise ArcanaUsageError(
f"Neither ID nor id_kwargs cannot were provided `row` method of {self}"
)
# Iterate through the tree to find the row (i.e. tree node) matching the
# provided IDs
row = self.root
cum_freq = self.space(0)
for freq, id in id_kwargs.items():
cum_freq |= freq
try:
row = row.children[cum_freq][id]
except KeyError as e:
raise ArcanaNameError(
id, f"{id} ({freq}) not a child row of {row}"
) from e
if cum_freq != frequency:
raise ArcanaUsageError(
f"Cumulative frequency of ID kwargs {id_kwargs} ({cum_freq}) does not "
"match that of row"
)
return row
def rows(self, frequency=None, ids=None):
"""Return all the IDs in the dataset for a given frequency
Parameters
----------
frequency : DataSpace, optional
The "frequency" of the rows, e.g. per-session, per-subject, defaults to
leaf rows
ids : Sequence[str or Tuple[str]]
The i
Returns
-------
Sequence[DataRow]
The sequence of the data row within the dataset
"""
if frequency is None:
frequency = max(self.space) # "leaf" nodes of the data tree
else:
frequency = self.parse_frequency(frequency)
with self.tree:
if frequency == self.root_freq:
return [self.root]
rows = self.root.children[frequency].values()
if ids is not None:
rows = (n for n in rows if n.id in set(ids))
return rows
def row_ids(self, frequency: ty.Optional[str] = None):
"""Return all the IDs in the dataset for a given row_frequency
Parameters
----------
frequency : str
The "frequency" of the rows to return the IDs for, e.g. per-session, per-subject...
Returns
-------
Sequence[str]
The IDs of the rows
"""
if frequency is None:
frequency = max(self.space) # "leaf" nodes of the data tree
else:
frequency = self.parse_frequency(frequency)
if frequency == self.root_freq:
return [None]
with self.tree:
try:
return self.root.children[frequency].keys()
except KeyError:
return ()
def __getitem__(self, name):
"""Return all data items across the dataset for a given source or sink
Parameters
----------
name : str
Name of the column to return
Returns
-------
DataColumn
the column object
"""
return self.columns[name]
def apply_pipeline(
self,
name,
workflow,
inputs,
outputs,
row_frequency=None,
overwrite=False,
converter_args=None,
):
"""Connect a Pydra workflow as a pipeline of the dataset
Parameters
----------
name : str
name of the pipeline
workflow : pydra.Workflow
pydra workflow to connect to the dataset as a pipeline
inputs : list[arcana.core.analysis.pipeline.Input or tuple[str, str, type] or tuple[str, str]]
List of inputs to the pipeline (see `arcana.core.analysis.pipeline.Pipeline.PipelineInput`)
outputs : list[arcana.core.analysis.pipeline.Output or tuple[str, str, type] or tuple[str, str]]
List of outputs of the pipeline (see `arcana.core.analysis.pipeline.Pipeline.PipelineOutput`)
row_frequency : str, optional
the frequency of the data rows the pipeline will be executed over, i.e.
will it be run once per-session, per-subject or per whole dataset,
by default the highest row frequency (e.g. per-session for Clinical)
overwrite : bool, optional
overwrite connections to previously connected sinks, by default False
converter_args : dict[str, dict]
keyword arguments passed on to the converter to control how the
conversion is performed.
Returns
-------
Pipeline
the pipeline added to the dataset
Raises
------
ArcanaUsageError
if overwrite is false and
"""
from arcana.core.analysis.pipeline import Pipeline
row_frequency = self.parse_frequency(row_frequency)
# def parsed_conns(lst, conn_type):
# parsed = []
# for spec in lst:
# if isinstance(spec, conn_type):
# parsed.append(spec)
# elif len(spec) == 3:
# parsed.append(conn_type(*spec))
# else:
# col_name, field = spec
# parsed.append(conn_type(col_name, field, self[col_name].datatype))
# return parsed
pipeline = Pipeline(
name=name,
dataset=self,
row_frequency=row_frequency,
workflow=workflow,
inputs=inputs,
outputs=outputs,
converter_args=converter_args,
)
for outpt in pipeline.outputs:
sink = self[outpt.name]
if sink.pipeline_name is not None:
if overwrite:
logger.info(
f"Overwriting pipeline of sink '{outpt.name}' "
f"{sink.pipeline_name} with {name}"
)
else:
raise ArcanaUsageError(
f"Attempting to overwrite pipeline of '{outpt.name}' "
f"sink ({sink.pipeline_name}) with {name}. Use "
f"'overwrite' option if this is desired"
)
sink.pipeline_name = pipeline.name
self.pipelines[name] = pipeline
return pipeline
def apply(self, analysis):
self.analyses[analysis.name] = analysis
def derive(self, *sink_names, ids=None, cache_dir=None, **kwargs):
"""Generate derivatives from the workflows
Parameters
----------
*sink_names : Iterable[str]
Names of the columns corresponding to the items to derive
ids : Iterable[str]
The IDs of the data rows in each column to derive
cache_dir
Returns
-------
Sequence[List[DataType]]
The derived columns
"""
from arcana.core.analysis.pipeline import Pipeline
sinks = [self[s] for s in set(sink_names)]
for pipeline, _ in Pipeline.stack(*sinks):
# Execute pipelines in stack
# FIXME: Should combine the pipelines into a single workflow and
# dilate the IDs that need to be run when summarising over different
# data axes
with self.tree:
pipeline(ids=ids, cache_dir=cache_dir)(**kwargs)
def parse_frequency(self, freq):
"""Parses the data row_frequency, converting from string if necessary and
checks it matches the dimensions of the dataset"""
if freq is None:
return max(self.space)
try:
if isinstance(freq, str):
freq = self.space[freq]
elif not isinstance(freq, self.space):
raise KeyError
except KeyError as e:
raise ArcanaWrongDataSpaceError(
f"{freq} is not a valid dimension for {self} " f"({self.space})"
) from e
return freq
@classmethod
def _sink_path(cls, workflow_name, sink_name):
return f"{workflow_name}/{sink_name}"
@classmethod
def parse_id_str(cls, id):
parts = id.split("//")
if len(parts) == 1: # No store definition, default to the `DirTree` store
store_name = "dirtree"
else:
store_name, id = parts
parts = id.split("@")
if len(parts) == 1:
name = ""
else:
id, name = parts
return store_name, id, name
def download_licenses(self, licenses: ty.List[License]):
"""Install licenses from project-specific location in data store and
install them at the destination location
Parameters
----------
licenses : list[License]
the list of licenses stored in the dataset or in a site-wide location that
need to be downloaded to the local file-system before a pipeline is run
Raises
------
ArcanaLicenseNotFoundError
raised if the license of the given name isn't present in the project-specific
location to retrieve
"""
from arcana.core.deploy.image.components import License
site_licenses_dataset = self.store.site_licenses_dataset()
for lic in licenses:
missing = False
try:
license_file = self.get_license_file(lic.name)
except ArcanaDataMatchError:
if site_licenses_dataset is not None:
try:
license_file = self.get_license_file(
lic.name, dataset=site_licenses_dataset
)
except ArcanaDataMatchError:
missing = True
else:
missing = True
if missing:
msg = (
f"Did not find a license corresponding to '{lic.name}' at "
f"{License.column_path(lic.name)} in {self}"
)
if site_licenses_dataset:
msg += f" or {site_licenses_dataset}"
raise ArcanaLicenseNotFoundError(
lic.name,
msg,
)
shutil.copyfile(license_file, lic.destination)
def install_license(self, name: str, source_file: PlainText):
"""Store project-specific license in dataset
Parameters
----------
name : str
name of the license to install
source_file : PlainText
the license file to install
"""
from arcana.core.deploy.image.components import License
try:
entry = self._get_license_entry(name)
except ArcanaDataMatchError:
entry = self.store.create_entry(
License.column_path(name), PlainText, self.root
)
self.store.put(PlainText(source_file), entry)
def _get_license_entry(self, name, dataset=None) -> DataEntry:
from arcana.core.deploy.image.components import License
if dataset is None:
dataset = self
column = DataSink(
name=f"{name}_license",
datatype=PlainText,
row_frequency=self.root_freq,
dataset=dataset,
path=License.column_path(name),
)
return column.match_entry(dataset.root)
def get_license_file(self, name, dataset=None) -> PlainText:
return PlainText(self._get_license_entry(name, dataset).item)
def infer_ids(
self, ids: ty.Dict[str, str], metadata: ty.Dict[str, ty.Dict[str, str]]
):
return self.store.infer_ids(
ids=ids, id_patterns=self.id_patterns, metadata=metadata
)
def __bytes_repr__(self, cache):
"""For Pydra input hashing"""
yield f"{type(self).__module__}.{type(self).__name__}(".encode()
yield self.id.encode()
yield bytes(hash_single(self.store, cache))
yield bytes(hash_single(self.space, cache))
yield bytes(hash_single(self.include, cache))
yield bytes(hash_single(self.exclude, cache))
yield self.name.encode()
yield from bytes_repr_mapping_contents(self.columns, cache)
@attrs.define
class SplitDataset:
"""A dataset created by combining multiple datasets into a conglomerate
Parameters
----------
"""
source_dataset: Dataset = attrs.field()
sink_dataset: Dataset = attrs.field()