From 6bd00067a90a120fc37d2976cfd015cd695ecaa2 Mon Sep 17 00:00:00 2001
From: Christopher Spinrath <christopher.spinrath@univ-grenoble-alpes.fr>
Date: Wed, 30 Apr 2025 14:26:40 +0200
Subject: [PATCH] Evaluate variables lazily

---
 run-exp.py | 99 ++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 70 insertions(+), 29 deletions(-)

diff --git a/run-exp.py b/run-exp.py
index c8b9453..83e3103 100755
--- a/run-exp.py
+++ b/run-exp.py
@@ -4,22 +4,23 @@ import itertools
 import pathlib
 import subprocess
 
+from abc import ABC, abstractmethod
+
 import click
 import yaml
 
 
-class Variable:
-    def __init__(self, name, values):
+class Variable(ABC):
+    def __init__(self, name):
         self._name = name
-        self._values = values
 
     @property
     def name(self):
         return self._name
 
-    @property
-    def values(self):
-        return self._values
+    @abstractmethod
+    def evaluate(self, variable_mapping):
+        pass
 
     @classmethod
     def new_scalar(cls, name, value):
@@ -29,34 +30,45 @@ class Variable:
     def new_multi_valued(cls, name, values):
         return cls(name, values)
 
-    @classmethod
-    def new(cls, name, value_config):
-        if isinstance(value_config, list):
-            return cls.new_multi_valued(name, value_config)
-        elif not isinstance(value_config, dict):  # scalar value
-            return cls.new_scalar(name, value_config)
 
-        var_type = value_config["type"]
-        assert var_type in ["file", "range"], "Unsupported variable type."
+class SimpleVariable(Variable):
+    def __init__(self, name, values):
+        super().__init__(name)
+        self._values = values
 
-        if var_type == "file":
-            directory = pathlib.Path(value_config["directory"]).expanduser()
-            assert directory.exists() and directory.is_dir(), f"No such directory: {directory}"
+    def evaluate(self, variable_mapping):
+        # String values are format templates; everything else is passed through as is.
+        def fill(value):
+            if not isinstance(value, str):
+                return value
+            return value.format(**variable_mapping)
 
-            values = sorted(list(directory.iterdir()))
+        return list(map(fill, self._values))
 
-            if value_config.get("basename", False):
-                values = [v.name for v in values]
 
-        if var_type == "range":
-            assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\""
-            upper = value_config["max"]
-            lower = value_config.get("min", 0)
-            step = value_config.get("step", 1)
+class RangeVariable(SimpleVariable):
+    def __init__(self, name, lower, upper, step):
+        super().__init__(name, [*range(lower, upper + 1, step)])  # "+ 1": upper is included
+
+
+class FileVariable(Variable):
+    def __init__(self, name, directory_str, basename_only):
+        super().__init__(name)
+        self._directory_str = directory_str
+        self._basename_only = basename_only
+
+    def evaluate(self, variable_mapping):
+        # Fill "{name}" placeholders in the configured path first, then expand "~".
+        directory = pathlib.Path(self._directory_str.format(**variable_mapping)).expanduser()
 
-            values = list(range(lower, upper + 1, step))
+        assert directory.is_dir(), f"No such directory: {directory}"  # False if missing, too
 
-        return cls.new_multi_valued(name, values)
+        values = sorted(directory.iterdir())  # sorted for a reproducible order
+
+        if self._basename_only:
+            values = [v.name for v in values]  # final path component only
+
+        return values
 
 
 class Task:
@@ -173,6 +185,33 @@ class Config:
     def tasks(self):
         return self._tasks
 
+    @staticmethod
+    def parse_var_definition(name, value_config):
+        """Build the variable described by one entry of the "variables" config mapping.
+
+        Lists and other non-dict values are handed to SimpleVariable's factory methods;
+        a dict is dispatched on its "type" key, which must be either "file" or "range".
+        """
+        if isinstance(value_config, list):
+            return SimpleVariable.new_multi_valued(name, value_config)
+        if not isinstance(value_config, dict):  # scalar value
+            return SimpleVariable.new_scalar(name, value_config)
+
+        # Dict form: the "type" key selects the variable class.
+        kind = value_config["type"]
+        assert kind in ["file", "range"], "Unsupported variable type."
+
+        if kind == "file":
+            basename_only = value_config.get("basename", False)
+            return FileVariable(name, value_config["directory"], basename_only)
+
+        if kind == "range":
+            assert "max" in value_config, f"Missing \"max\" value for range variable \"{name}\""
+
+            lower = value_config.get("min", 0)
+            step = value_config.get("step", 1)
+            return RangeVariable(name, lower, value_config["max"], step)
+
     @classmethod
     def from_file(cls, filepath: pathlib.Path):
         with filepath.open('r') as file:
@@ -180,7 +219,9 @@ class Config:
 
         name = raw_config.get("name", str(filepath))
         repetitions = raw_config.get("repetitions", 1)
-        variables = [Variable.new(k, v) for k, v in raw_config.get("variables", {}).items()]
+        variables = [
+            cls.parse_var_definition(k, v) for k, v in raw_config.get("variables", {}).items()
+        ]
         tasks = [Task.from_dict(t) for t in raw_config.get("tasks", [])]
 
         return cls(name, repetitions, variables, tasks)
@@ -200,7 +241,7 @@ def main(config_file, global_timeout, dry_run):
     print(f">     and {c.repetitions} repetitions")
     print()
 
-    variable_map = {v.name: v.values for v in c.variables}
+    variable_map = {v.name: v.evaluate({}) for v in c.variables}
     value_combinations = itertools.product(*variable_map.values())
     run_values = [dict(zip(variable_map.keys(), vals)) for vals in value_combinations]
 
-- 
GitLab