Skip to content
Snippets Groups Projects
Commit 6bd00067 authored by Christopher Spinrath's avatar Christopher Spinrath
Browse files

Evaluate variables lazily

parent b20868ad
No related branches found
No related tags found
No related merge requests found
......@@ -4,22 +4,23 @@ import itertools
import pathlib
import subprocess
from abc import ABC, abstractmethod
import click
import yaml
class Variable:
def __init__(self, name, values):
class Variable(ABC):
def __init__(self, name):
self._name = name
self._values = values
@property
def name(self):
return self._name
@property
def values(self):
return self._values
@abstractmethod
def evaluate(self, variable_mapping):
pass
@classmethod
def new_scalar(cls, name, value):
......@@ -29,34 +30,45 @@ class Variable:
def new_multi_valued(cls, name, values):
return cls(name, values)
@classmethod
def new(cls, name, value_config):
if isinstance(value_config, list):
return cls.new_multi_valued(name, value_config)
elif not isinstance(value_config, dict): # scalar value
return cls.new_scalar(name, value_config)
var_type = value_config["type"]
assert var_type in ["file", "range"], "Unsupported variable type."
class SimpleVariable(Variable):
def __init__(self, name, values):
super().__init__(name)
self._values = values
if var_type == "file":
directory = pathlib.Path(value_config["directory"]).expanduser()
assert directory.exists() and directory.is_dir(), f"No such directory: {directory}"
def evaluate(self, variable_mapping):
def substitute(v):
if isinstance(v, str):
return v.format(**variable_mapping)
else:
return v
values = sorted(list(directory.iterdir()))
return [substitute(v) for v in self._values]
if value_config.get("basename", False):
values = [v.name for v in values]
if var_type == "range":
assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\""
upper = value_config["max"]
lower = value_config.get("min", 0)
step = value_config.get("step", 1)
class RangeVariable(SimpleVariable):
def __init__(self, name, lower, upper, step):
super().__init__(name, list(range(lower, upper + 1, step)))
class FileVariable(Variable):
def __init__(self, name, directory_str, basename_only):
super().__init__(name)
self._directory_str = directory_str
self._basename_only = basename_only
def evaluate(self, variable_mapping):
directory = pathlib.Path(self._directory_str.format(**variable_mapping))
directory = directory.expanduser()
values = list(range(lower, upper + 1, step))
assert directory.exists() and directory.is_dir(), f"No such directory: {directory}"
return cls.new_multi_valued(name, values)
values = sorted(list(directory.iterdir()))
if self._basename_only:
values = [v.name for v in values]
return values
class Task:
......@@ -173,6 +185,33 @@ class Config:
def tasks(self):
return self._tasks
@staticmethod
def parse_var_definition(name, value_config):
if isinstance(value_config, list):
return SimpleVariable.new_multi_valued(name, value_config)
elif not isinstance(value_config, dict): # scalar value
return SimpleVariable.new_scalar(name, value_config)
var_type = value_config["type"]
assert var_type in ["file", "range"], "Unsupported variable type."
if var_type == "file":
return FileVariable(
name,
value_config["directory"],
value_config.get("basename", False),
)
if var_type == "range":
assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\""
return RangeVariable(
name,
value_config.get("min", 0),
value_config["max"],
value_config.get("step", 1),
)
@classmethod
def from_file(cls, filepath: pathlib.Path):
with filepath.open('r') as file:
......@@ -180,7 +219,9 @@ class Config:
name = raw_config.get("name", str(filepath))
repetitions = raw_config.get("repetitions", 1)
variables = [Variable.new(k, v) for k, v in raw_config.get("variables", {}).items()]
variables = [
cls.parse_var_definition(k, v) for k, v in raw_config.get("variables", {}).items()
]
tasks = [Task.from_dict(t) for t in raw_config.get("tasks", [])]
return cls(name, repetitions, variables, tasks)
......@@ -200,7 +241,7 @@ def main(config_file, global_timeout, dry_run):
print(f"> and {c.repetitions} repetitions")
print()
variable_map = {v.name: v.values for v in c.variables}
variable_map = {v.name: v.evaluate({}) for v in c.variables}
value_combinations = itertools.product(*variable_map.values())
run_values = [dict(zip(variable_map.keys(), vals)) for vals in value_combinations]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment