Skip to content
Snippets Groups Projects
Commit 6bd00067 authored by Christopher Spinrath's avatar Christopher Spinrath
Browse files

Evaluate variables lazily

parent b20868ad
No related branches found
No related tags found
No related merge requests found
...@@ -4,22 +4,23 @@ import itertools ...@@ -4,22 +4,23 @@ import itertools
import pathlib import pathlib
import subprocess import subprocess
from abc import ABC, abstractmethod
import click import click
import yaml import yaml
class Variable: class Variable(ABC):
def __init__(self, name, values): def __init__(self, name):
self._name = name self._name = name
self._values = values
@property @property
def name(self): def name(self):
return self._name return self._name
@property @abstractmethod
def values(self): def evaluate(self, variable_mapping):
return self._values pass
@classmethod @classmethod
def new_scalar(cls, name, value): def new_scalar(cls, name, value):
...@@ -29,34 +30,45 @@ class Variable: ...@@ -29,34 +30,45 @@ class Variable:
def new_multi_valued(cls, name, values): def new_multi_valued(cls, name, values):
return cls(name, values) return cls(name, values)
@classmethod
def new(cls, name, value_config):
if isinstance(value_config, list):
return cls.new_multi_valued(name, value_config)
elif not isinstance(value_config, dict): # scalar value
return cls.new_scalar(name, value_config)
var_type = value_config["type"] class SimpleVariable(Variable):
assert var_type in ["file", "range"], "Unsupported variable type." def __init__(self, name, values):
super().__init__(name)
self._values = values
if var_type == "file": def evaluate(self, variable_mapping):
directory = pathlib.Path(value_config["directory"]).expanduser() def substitute(v):
assert directory.exists() and directory.is_dir(), f"No such directory: {directory}" if isinstance(v, str):
return v.format(**variable_mapping)
else:
return v
values = sorted(list(directory.iterdir())) return [substitute(v) for v in self._values]
if value_config.get("basename", False):
values = [v.name for v in values]
if var_type == "range": class RangeVariable(SimpleVariable):
assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\"" def __init__(self, name, lower, upper, step):
upper = value_config["max"] super().__init__(name, list(range(lower, upper + 1, step)))
lower = value_config.get("min", 0)
step = value_config.get("step", 1)
class FileVariable(Variable):
def __init__(self, name, directory_str, basename_only):
super().__init__(name)
self._directory_str = directory_str
self._basename_only = basename_only
def evaluate(self, variable_mapping):
directory = pathlib.Path(self._directory_str.format(**variable_mapping))
directory = directory.expanduser()
values = list(range(lower, upper + 1, step)) assert directory.exists() and directory.is_dir(), f"No such directory: {directory}"
return cls.new_multi_valued(name, values) values = sorted(list(directory.iterdir()))
if self._basename_only:
values = [v.name for v in values]
return values
class Task: class Task:
...@@ -173,6 +185,33 @@ class Config: ...@@ -173,6 +185,33 @@ class Config:
def tasks(self): def tasks(self):
return self._tasks return self._tasks
@staticmethod
def parse_var_definition(name, value_config):
if isinstance(value_config, list):
return SimpleVariable.new_multi_valued(name, value_config)
elif not isinstance(value_config, dict): # scalar value
return SimpleVariable.new_scalar(name, value_config)
var_type = value_config["type"]
assert var_type in ["file", "range"], "Unsupported variable type."
if var_type == "file":
return FileVariable(
name,
value_config["directory"],
value_config.get("basename", False),
)
if var_type == "range":
assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\""
return RangeVariable(
name,
value_config.get("min", 0),
value_config["max"],
value_config.get("step", 1),
)
@classmethod @classmethod
def from_file(cls, filepath: pathlib.Path): def from_file(cls, filepath: pathlib.Path):
with filepath.open('r') as file: with filepath.open('r') as file:
...@@ -180,7 +219,9 @@ class Config: ...@@ -180,7 +219,9 @@ class Config:
name = raw_config.get("name", str(filepath)) name = raw_config.get("name", str(filepath))
repetitions = raw_config.get("repetitions", 1) repetitions = raw_config.get("repetitions", 1)
variables = [Variable.new(k, v) for k, v in raw_config.get("variables", {}).items()] variables = [
cls.parse_var_definition(k, v) for k, v in raw_config.get("variables", {}).items()
]
tasks = [Task.from_dict(t) for t in raw_config.get("tasks", [])] tasks = [Task.from_dict(t) for t in raw_config.get("tasks", [])]
return cls(name, repetitions, variables, tasks) return cls(name, repetitions, variables, tasks)
...@@ -200,7 +241,7 @@ def main(config_file, global_timeout, dry_run): ...@@ -200,7 +241,7 @@ def main(config_file, global_timeout, dry_run):
print(f"> and {c.repetitions} repetitions") print(f"> and {c.repetitions} repetitions")
print() print()
variable_map = {v.name: v.values for v in c.variables} variable_map = {v.name: v.evaluate({}) for v in c.variables}
value_combinations = itertools.product(*variable_map.values()) value_combinations = itertools.product(*variable_map.values())
run_values = [dict(zip(variable_map.keys(), vals)) for vals in value_combinations] run_values = [dict(zip(variable_map.keys(), vals)) for vals in value_combinations]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment