github-actions[bot] committed on
Commit
e83942e
·
1 Parent(s): 817f4f7

deploy: sync from GitHub 2026-05-13T14:25:24Z

Browse files
generate_heatmap.py CHANGED
@@ -8,6 +8,8 @@ import matplotlib.patches as mpatches
8
  from matplotlib.lines import Line2D
9
  import matplotlib.cm as cm
10
  import matplotlib.colors as mcolors
 
 
11
 
12
 
13
 
@@ -196,12 +198,13 @@ def generate_heatmap(voltages, output_path, vmin=0.92, vmax=1.06, map_image="13b
196
  plt.rcParams['svg.fonttype'] = 'none'
197
  plt.savefig(output_path, format='svg', bbox_inches='tight', dpi=150, facecolor='white')
198
  plt.close()
199
- print(f"[ok] saved {output_path}")
200
 
201
 
202
  if __name__ == "__main__":
203
  if len(sys.argv) < 15:
204
- print("Usage: generate_heatmap.py <output.png> <v1> <v2> ... <v13> [dc_bus_idx]")
 
205
  sys.exit(1)
206
  out = sys.argv[1]
207
  volts = [float(v) for v in sys.argv[2:15]]
 
8
  from matplotlib.lines import Line2D
9
  import matplotlib.cm as cm
10
  import matplotlib.colors as mcolors
11
+ import logging
12
+ logger = logging.getLogger(__name__)
13
 
14
 
15
 
 
198
  plt.rcParams['svg.fonttype'] = 'none'
199
  plt.savefig(output_path, format='svg', bbox_inches='tight', dpi=150, facecolor='white')
200
  plt.close()
201
+ logger.info(f"Saved heatmap: {output_path}")
202
 
203
 
204
  if __name__ == "__main__":
205
  if len(sys.argv) < 15:
206
+ logger.info(f"Usage: generate_heatmap.py <output.png> <v1> <v2> ... <v13> [dc_bus_idx]")
207
+
208
  sys.exit(1)
209
  out = sys.argv[1]
210
  volts = [float(v) for v in sys.argv[2:15]]
openg2g/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- """OpenG2G: GPU-to-Grid framework for distribution-level voltage regulation."""
2
-
3
- __version__ = "0.1.0"
 
 
 
 
openg2g/clock.py DELETED
@@ -1,91 +0,0 @@
1
- """Simulation clock with multi-rate support and optional live-mode wall-clock sync."""
2
-
3
- from __future__ import annotations
4
-
5
- import time
6
- import warnings
7
- from dataclasses import dataclass, field
8
- from fractions import Fraction
9
-
10
-
11
- @dataclass
12
- class SimulationClock:
13
- """Integer-tick clock that avoids floating-point drift.
14
-
15
- Components run at different rates (DC=0.1s, Grid=1.0s, Controller=1.0s or 60s).
16
- The coordinator computes `tick_s` as the GCD of all component periods.
17
-
18
- All time step parameters use `fractions.Fraction` for exact arithmetic.
19
- The `time_s` property returns `float` for compatibility with numpy/plotting.
20
-
21
- In live mode (`live=True`), the clock synchronizes with wall-clock time.
22
- If computation falls behind, a warning is issued.
23
-
24
- Attributes:
25
- tick_s: Duration of one tick as a `Fraction` (seconds).
26
- live: If `True`, synchronize with wall-clock time.
27
- """
28
-
29
- tick_s: Fraction
30
- live: bool = False
31
- _step: int = field(default=0, init=False, repr=False)
32
- _wall_t0: float | None = field(default=None, init=False, repr=False)
33
-
34
- def __post_init__(self) -> None:
35
- if not isinstance(self.tick_s, Fraction):
36
- raise TypeError(f"tick_s must be a Fraction, got {type(self.tick_s).__name__}")
37
- if self.tick_s <= 0:
38
- raise ValueError(f"tick_s must be positive, got {self.tick_s}")
39
-
40
- @property
41
- def time_s(self) -> float:
42
- return float(self._step * self.tick_s)
43
-
44
- @property
45
- def step(self) -> int:
46
- return self._step
47
-
48
- def advance(self) -> float:
49
- """Advance one tick.
50
-
51
- Returns:
52
- New simulation time in seconds.
53
- """
54
- self._step += 1
55
- if self.live:
56
- if self._wall_t0 is None:
57
- self._wall_t0 = time.monotonic()
58
- expected_wall = self._wall_t0 + self.time_s
59
- now = time.monotonic()
60
- if now < expected_wall:
61
- time.sleep(expected_wall - now)
62
- elif now - expected_wall > float(self.tick_s):
63
- lag = now - expected_wall
64
- warnings.warn(
65
- f"Clock lag: {lag:.3f}s behind wall time at sim t={self.time_s:.1f}s. "
66
- f"Control loop cannot keep up with real-time.",
67
- stacklevel=2,
68
- )
69
- return self.time_s
70
-
71
- def reset(self) -> None:
72
- """Reset clock to initial state (tick 0)."""
73
- self._step = 0
74
- self._wall_t0 = None
75
-
76
- def is_due(self, period_s: Fraction) -> bool:
77
- """Check if an event with the given period should fire on this tick.
78
-
79
- Returns:
80
- `True` if this tick is a multiple of the period.
81
-
82
- Raises:
83
- ValueError: If *period_s* is not an exact multiple of *tick_s*.
84
- """
85
- if period_s <= 0:
86
- raise ValueError(f"period_s must be positive, got {period_s}")
87
- ratio = period_s / self.tick_s
88
- if ratio.denominator != 1:
89
- raise ValueError(f"period_s={period_s} is not an exact multiple of tick_s={self.tick_s}")
90
- period_ticks = int(ratio)
91
- return self._step % period_ticks == 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/common.py DELETED
@@ -1,20 +0,0 @@
1
- """Cross-cutting types shared across component families."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass
6
-
7
-
8
- @dataclass(frozen=True)
9
- class ThreePhase:
10
- """Three-phase quantity. Access via `.a`, `.b`, `.c`.
11
-
12
- Attributes:
13
- a: Phase A value.
14
- b: Phase B value.
15
- c: Phase C value.
16
- """
17
-
18
- a: float
19
- b: float
20
- c: float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/controller/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Controllers receive datacenter and grid state and produce control actions."""
 
 
openg2g/controller/base.py DELETED
@@ -1,153 +0,0 @@
1
- """Abstract base class for controllers."""
2
-
3
- from __future__ import annotations
4
-
5
- from abc import ABC, abstractmethod
6
- from fractions import Fraction
7
- from typing import Generic, TypeVar, Union, final, get_args, get_origin
8
-
9
- from openg2g.clock import SimulationClock
10
- from openg2g.datacenter.base import DatacenterBackend
11
- from openg2g.datacenter.command import DatacenterCommand
12
- from openg2g.events import EventEmitter
13
- from openg2g.grid.base import GridBackend
14
- from openg2g.grid.command import GridCommand
15
-
16
- DCBackendT = TypeVar("DCBackendT", bound=DatacenterBackend)
17
- GridBackendT = TypeVar("GridBackendT", bound=GridBackend)
18
-
19
-
20
- def _normalize_backend_type_arg(
21
- arg: object,
22
- *,
23
- required_base: type[object],
24
- ) -> tuple[type[object], ...]:
25
- if isinstance(arg, type):
26
- if issubclass(arg, required_base):
27
- return (arg,)
28
- raise TypeError(f"Controller generic type {arg!r} is not a subclass of {required_base.__name__}.")
29
-
30
- origin = get_origin(arg)
31
-
32
- # Handle parameterized generics like DatacenterBackend[OfflineDatacenterState]
33
- if isinstance(origin, type) and issubclass(origin, required_base):
34
- return (origin,)
35
-
36
- if origin is Union:
37
- out: list[type[object]] = []
38
- for item in get_args(arg):
39
- item_type = item if isinstance(item, type) else get_origin(item)
40
- if not isinstance(item_type, type) or not issubclass(item_type, required_base):
41
- raise TypeError(f"Controller generic type {item!r} is not a subclass of {required_base.__name__}.")
42
- out.append(item_type)
43
- return tuple(out)
44
-
45
- raise TypeError(
46
- f"Unsupported controller generic type argument: {arg!r}. Use a concrete class (or Union of concrete classes)."
47
- )
48
-
49
-
50
- class Controller(Generic[DCBackendT, GridBackendT], ABC):
51
- """Interface for a control component in the G2G framework.
52
-
53
- Controllers receive datacenter and grid state and produce control actions.
54
- Multiple controllers compose in order within the coordinator.
55
- """
56
-
57
- _dc_types: tuple[type[DatacenterBackend], ...] = (DatacenterBackend,)
58
- _grid_types: tuple[type[GridBackend], ...] = (GridBackend,)
59
-
60
- def __init_subclass__(cls, **kwargs: object) -> None:
61
- super().__init_subclass__(**kwargs)
62
- dc_types: tuple[type[DatacenterBackend], ...] | None = None
63
- grid_types: tuple[type[GridBackend], ...] | None = None
64
- for base in getattr(cls, "__orig_bases__", ()):
65
- if get_origin(base) is Controller:
66
- args = get_args(base)
67
- if len(args) != 2:
68
- raise TypeError(
69
- f"{cls.__name__} must specialize Controller with two generic args: "
70
- "Controller[DatacenterType, GridType]."
71
- )
72
- dc_raw, grid_raw = args
73
- dc_norm = _normalize_backend_type_arg(dc_raw, required_base=DatacenterBackend)
74
- grid_norm = _normalize_backend_type_arg(grid_raw, required_base=GridBackend)
75
- dc_types = tuple(t for t in dc_norm if issubclass(t, DatacenterBackend))
76
- grid_types = tuple(t for t in grid_norm if issubclass(t, GridBackend))
77
- break
78
-
79
- if dc_types is None or grid_types is None:
80
- inherited = [b for b in cls.__bases__ if issubclass(b, Controller)]
81
- inherited = [b for b in inherited if b is not Controller]
82
- if inherited:
83
- parent = inherited[0]
84
- cls._dc_types = parent.compatible_datacenter_types()
85
- cls._grid_types = parent.compatible_grid_types()
86
- return
87
- raise TypeError(
88
- f"{cls.__name__} must explicitly specialize Controller generics as "
89
- "Controller[DatacenterType, GridType]."
90
- )
91
-
92
- cls._dc_types = dc_types
93
- cls._grid_types = grid_types
94
-
95
- @final
96
- @classmethod
97
- def compatible_datacenter_types(cls) -> tuple[type[DatacenterBackend], ...]:
98
- return cls._dc_types
99
-
100
- @final
101
- @classmethod
102
- def compatible_grid_types(cls) -> tuple[type[GridBackend], ...]:
103
- return cls._grid_types
104
-
105
- @final
106
- @classmethod
107
- def compatibility_signature(cls) -> str:
108
- dc = " | ".join(t.__name__ for t in cls.compatible_datacenter_types())
109
- grid = " | ".join(t.__name__ for t in cls.compatible_grid_types())
110
- return f"Controller[{dc}, {grid}]"
111
-
112
- @property
113
- @abstractmethod
114
- def dt_s(self) -> Fraction:
115
- """Control interval as a Fraction (seconds)."""
116
-
117
- @abstractmethod
118
- def reset(self) -> None:
119
- """Reset simulation state to initial conditions.
120
-
121
- Called by the coordinator before each [`start`][..start]. Must
122
- clear all simulation state: dual variables, counters, cached
123
- matrices. Configuration (dt_s, fits, step sizes) is not
124
- affected.
125
-
126
- Abstract so every implementation explicitly enumerates its state.
127
- A forgotten field is a bug -- not clearing it silently corrupts
128
- the second run.
129
- """
130
-
131
- def start(self) -> None:
132
- """Acquire per-run resources.
133
-
134
- Called after [`reset`][..reset], before the simulation loop.
135
- No-op by default because most controllers have no resources to
136
- acquire.
137
- """
138
-
139
- def stop(self) -> None:
140
- """Release per-run resources. Simulation state is preserved.
141
-
142
- Called after the simulation loop in LIFO order. No-op by default.
143
- """
144
-
145
- @abstractmethod
146
- def step(
147
- self,
148
- clock: SimulationClock,
149
- datacenter: DCBackendT,
150
- grid: GridBackendT,
151
- events: EventEmitter,
152
- ) -> list[DatacenterCommand | GridCommand]:
153
- """Compute control commands for this step. Return an empty list for no-op."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/controller/batch_size_schedule.py DELETED
@@ -1,159 +0,0 @@
1
- """Batch size schedule controller: applies pre-defined batch size changes at specified times."""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Iterator
6
- from dataclasses import dataclass
7
- from fractions import Fraction
8
-
9
- from openg2g.clock import SimulationClock
10
- from openg2g.controller.base import Controller
11
- from openg2g.datacenter.base import DatacenterBackend
12
- from openg2g.datacenter.command import DatacenterCommand, SetBatchSize
13
- from openg2g.events import EventEmitter
14
- from openg2g.grid.base import GridBackend
15
- from openg2g.grid.command import GridCommand
16
-
17
-
18
- @dataclass(frozen=True)
19
- class BatchSizeChange:
20
- """A batch size change event, optionally with gradual ramp-up.
21
-
22
- Attributes:
23
- batch_size: Target batch size (max_num_seqs).
24
- ramp_up_rate: Requests/second ramp-up rate. 0 means immediate.
25
- """
26
-
27
- batch_size: int
28
- ramp_up_rate: float = 0.0
29
-
30
- def __post_init__(self) -> None:
31
- if self.batch_size <= 0:
32
- raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
33
- if self.ramp_up_rate < 0:
34
- raise ValueError(f"ramp_up_rate must be >= 0, got {self.ramp_up_rate}.")
35
-
36
- def at(self, t: float) -> BatchSizeSchedule:
37
- """Schedule this change at time *t* seconds.
38
-
39
- Returns:
40
- A single-entry [`BatchSizeSchedule`][...BatchSizeSchedule].
41
- """
42
- return BatchSizeSchedule(((t, self),))
43
-
44
-
45
- class BatchSizeSchedule:
46
- """Ordered sequence of batch size changes, built with `|` operator.
47
-
48
- Example:
49
-
50
- ```python
51
- schedule = (
52
- BatchSizeChange(48).at(40)
53
- | BatchSizeChange(32).at(60)
54
- | BatchSizeChange(48, ramp_up_rate=4).at(280)
55
- )
56
- ```
57
-
58
- Raises:
59
- ValueError: If two entries share the same timestamp.
60
- """
61
-
62
- __slots__ = ("_entries",)
63
-
64
- def __init__(self, entries: tuple[tuple[float, BatchSizeChange], ...]) -> None:
65
- self._entries = tuple(sorted(entries, key=lambda e: e[0]))
66
- times = [t for t, _ in self._entries]
67
- if len(times) != len(set(times)):
68
- seen: set[float] = set()
69
- dupes = sorted({t for t in times if t in seen or seen.add(t)})
70
- raise ValueError(f"BatchSizeSchedule has duplicate timestamps: {dupes}")
71
-
72
- def __or__(self, other: BatchSizeSchedule) -> BatchSizeSchedule:
73
- return BatchSizeSchedule(self._entries + other._entries)
74
-
75
- def __iter__(self) -> Iterator[tuple[float, BatchSizeChange]]:
76
- return iter(self._entries)
77
-
78
- def __len__(self) -> int:
79
- return len(self._entries)
80
-
81
- def __bool__(self) -> bool:
82
- return bool(self._entries)
83
-
84
- def __repr__(self) -> str:
85
- parts: list[str] = []
86
- for t, c in self._entries:
87
- ramp = f", ramp_up_rate={c.ramp_up_rate}" if c.ramp_up_rate > 0 else ""
88
- parts.append(f"BatchSizeChange({c.batch_size}{ramp}).at(t={t})")
89
- return " | ".join(parts)
90
-
91
-
92
- class BatchSizeScheduleController(Controller[DatacenterBackend, GridBackend]):
93
- """Applies pre-defined batch size changes at scheduled times.
94
-
95
- Walks each model's schedule and emits
96
- [`SetBatchSize`][openg2g.datacenter.command.SetBatchSize] commands when the
97
- simulation clock reaches the scheduled time.
98
-
99
- Args:
100
- schedules: Per-model batch size schedules, keyed by model label.
101
- dt_s: How often the controller checks the schedule (seconds).
102
- """
103
-
104
- def __init__(
105
- self,
106
- *,
107
- schedules: dict[str, BatchSizeSchedule],
108
- dt_s: Fraction = Fraction(1),
109
- ) -> None:
110
- self._dt_s = dt_s
111
- self._schedules = dict(schedules)
112
- self._indices: dict[str, int] = {label: 0 for label in schedules}
113
-
114
- def reset(self) -> None:
115
- self._indices = {label: 0 for label in self._schedules}
116
-
117
- @property
118
- def dt_s(self) -> Fraction:
119
- return self._dt_s
120
-
121
- def step(
122
- self,
123
- clock: SimulationClock,
124
- datacenter: DatacenterBackend,
125
- grid: GridBackend,
126
- events: EventEmitter,
127
- ) -> list[DatacenterCommand | GridCommand]:
128
- t_now = clock.time_s
129
- batch_changes: dict[str, int] = {}
130
- ramp_rates: dict[str, float] = {}
131
-
132
- for label, schedule in self._schedules.items():
133
- entries = list(schedule)
134
- idx = self._indices[label]
135
-
136
- while idx < len(entries):
137
- t_ev, change = entries[idx]
138
- if float(t_ev) <= t_now + 1e-12:
139
- batch_changes[label] = change.batch_size
140
- if change.ramp_up_rate > 0:
141
- ramp_rates[label] = change.ramp_up_rate
142
- idx += 1
143
- else:
144
- break
145
-
146
- self._indices[label] = idx
147
-
148
- if batch_changes:
149
- events.emit(
150
- "controller.batch_schedule.fired",
151
- {"batch_size_by_model": batch_changes},
152
- )
153
- return [
154
- SetBatchSize(
155
- batch_size_by_model=batch_changes,
156
- ramp_up_rate_by_model=ramp_rates,
157
- )
158
- ]
159
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/controller/noop.py DELETED
@@ -1,36 +0,0 @@
1
- """No-op controller that does nothing."""
2
-
3
- from __future__ import annotations
4
-
5
- from fractions import Fraction
6
-
7
- from openg2g.clock import SimulationClock
8
- from openg2g.controller.base import Controller
9
- from openg2g.datacenter.base import DatacenterBackend
10
- from openg2g.datacenter.command import DatacenterCommand
11
- from openg2g.events import EventEmitter
12
- from openg2g.grid.base import GridBackend
13
- from openg2g.grid.command import GridCommand
14
-
15
-
16
- class NoopController(Controller[DatacenterBackend, GridBackend]):
17
- """Controller that always returns an empty action."""
18
-
19
- def __init__(self, dt_s: Fraction = Fraction(1)) -> None:
20
- self._dt_s = dt_s
21
-
22
- @property
23
- def dt_s(self) -> Fraction:
24
- return self._dt_s
25
-
26
- def reset(self) -> None:
27
- pass
28
-
29
- def step(
30
- self,
31
- clock: SimulationClock,
32
- datacenter: DatacenterBackend,
33
- grid: GridBackend,
34
- events: EventEmitter,
35
- ) -> list[DatacenterCommand | GridCommand]:
36
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/controller/ofo.py DELETED
@@ -1,793 +0,0 @@
1
- """Online Feedback Optimization (OFO) batch-size controller.
2
-
3
- Implements the primal-dual algorithm for joint voltage regulation and
4
- latency management via GPU batch size control.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import bisect
10
- import logging
11
- import math
12
- from fractions import Fraction
13
- from pathlib import Path
14
- from typing import Any
15
-
16
- import numpy as np
17
- import pandas as pd
18
- from mlenergy_data.modeling import LogisticModel
19
- from mlenergy_data.records import LLMRuns
20
- from pydantic import BaseModel, ConfigDict
21
-
22
- from openg2g.clock import SimulationClock
23
- from openg2g.controller.base import Controller
24
- from openg2g.datacenter.base import LLMBatchSizeControlledDatacenter, LLMDatacenterState
25
- from openg2g.datacenter.command import DatacenterCommand, SetBatchSize
26
- from openg2g.datacenter.config import InferenceModelSpec
27
- from openg2g.events import EventEmitter
28
- from openg2g.grid.command import GridCommand
29
- from openg2g.grid.opendss import OpenDSSGrid
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
-
34
- class OFOConfig(BaseModel):
35
- """Online Feedback Optimization tuning parameters.
36
-
37
- Attributes:
38
- primal_step_size: Primal descent step size ρ_x (Eq. 8).
39
- w_throughput: Throughput weight in primal gradient.
40
- w_switch: Switching cost regularizer weight γ (Eq. 4a).
41
- voltage_gradient_scale: Scaling factor k_v for voltage dual term
42
- in the primal gradient.
43
- v_min: Lower voltage bound (pu).
44
- v_max: Upper voltage bound (pu).
45
- voltage_dual_step_size: Voltage dual ascent step size ρ_v (Eqs. 5-6).
46
- latency_dual_step_size: Latency dual ascent step size ρ_l (Eq. 7).
47
- sensitivity_update_interval: Steps between H-matrix re-estimation
48
- (0 = only once at init).
49
- sensitivity_perturbation_kw: Perturbation magnitude (kW) for
50
- finite-difference sensitivity estimation.
51
- """
52
-
53
- model_config = ConfigDict(frozen=True)
54
-
55
- # Primal
56
- primal_step_size: float = 0.05
57
- w_throughput: float = 0.1
58
- w_switch: float = 0.0
59
- voltage_gradient_scale: float = 1e6
60
-
61
- # Dual
62
- v_min: float = 0.95
63
- v_max: float = 1.05
64
- voltage_dual_step_size: float = 0.5
65
- latency_dual_step_size: float = 1.0
66
-
67
- # Sensitivity
68
- sensitivity_update_interval: int = 0
69
- sensitivity_perturbation_kw: float = 100.0
70
-
71
-
72
- class LogisticModelStore:
73
- """Per-model logistic models for power, latency, and throughput.
74
-
75
- Used by
76
- [`OFOBatchSizeController`][openg2g.controller.ofo.OFOBatchSizeController]
77
- to compute gradients of the Lagrangian with respect to batch size.
78
-
79
- Attributes:
80
- COL_MODEL_LABEL: Column name for model label in the CSV.
81
- COL_METRIC: Column name for metric type in the CSV.
82
- """
83
-
84
- COL_MODEL_LABEL = "model_label"
85
- COL_METRIC = "metric"
86
-
87
- def __init__(
88
- self,
89
- power: dict[str, LogisticModel],
90
- latency: dict[str, LogisticModel],
91
- throughput: dict[str, LogisticModel],
92
- ) -> None:
93
- self._power = dict(power)
94
- self._latency = dict(latency)
95
- self._throughput = dict(throughput)
96
- self._by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] | None = None
97
-
98
- def power(self, model: str) -> LogisticModel:
99
- """Return the power logistic model for a model label."""
100
- return self._power[model]
101
-
102
- def latency(self, model: str) -> LogisticModel:
103
- """Return the latency logistic model for a model label."""
104
- return self._latency[model]
105
-
106
- def throughput(self, model: str) -> LogisticModel:
107
- """Return the throughput logistic model for a model label."""
108
- return self._throughput[model]
109
-
110
- @property
111
- def power_fits(self) -> dict[str, LogisticModel]:
112
- return dict(self._power)
113
-
114
- @property
115
- def latency_fits(self) -> dict[str, LogisticModel]:
116
- return dict(self._latency)
117
-
118
- @property
119
- def throughput_fits(self) -> dict[str, LogisticModel]:
120
- return dict(self._throughput)
121
-
122
- @classmethod
123
- def generate(
124
- cls,
125
- models: tuple[InferenceModelSpec, ...],
126
- data_sources: dict[str, Any],
127
- *,
128
- runs: Any = None,
129
- mlenergy_data_dir: Path | None = None,
130
- ) -> LogisticModelStore:
131
- """Generate logistic fits from ML.ENERGY benchmark data.
132
-
133
- Args:
134
- models: Model specifications.
135
- data_sources: Per-model `MLEnergySource` instances, keyed by
136
- `model_label`.
137
- runs: Pre-loaded `LLMRuns` object. If `None`, loads from
138
- `mlenergy_data_dir` or the HuggingFace Hub.
139
- mlenergy_data_dir: Path to compiled mlenergy-data directory.
140
- Ignored if `runs` is provided.
141
-
142
- Returns:
143
- A new `LogisticModelStore` with fitted logistic models.
144
- """
145
- if runs is None:
146
- unique_tasks = {src.task for src in data_sources.values()}
147
- if mlenergy_data_dir:
148
- runs = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False).task(*unique_tasks)
149
- else:
150
- runs = LLMRuns.from_hf(stable_only=False).task(*unique_tasks)
151
- if not runs:
152
- raise ValueError("No runs found for the specified tasks")
153
-
154
- subsets_by_label: dict[str, Any] = {}
155
- for ms in models:
156
- src = data_sources.get(ms.model_label)
157
- if src is None:
158
- raise ValueError(f"No data source for model {ms.model_label!r}")
159
- model_id = ms.model_id
160
- if not model_id:
161
- raise ValueError(f"model_id is required for data generation (model={ms.model_label!r})")
162
-
163
- subset = (
164
- runs.model_id(model_id).gpu_model(src.gpu).num_gpus(ms.gpus_per_replica).max_num_seqs(*src.batch_sizes)
165
- )
166
- if not subset:
167
- raise ValueError(
168
- f"Config matched zero runs for logistic fits: model_id={model_id!r}, "
169
- f"gpu={src.gpu!r}, num_gpus={ms.gpus_per_replica}, "
170
- f"batch_sizes={src.batch_sizes}"
171
- )
172
- subsets_by_label[ms.model_label] = subset
173
-
174
- all_by_batch: dict[str, dict[int, list[tuple[float, float, float]]]] = {}
175
- power: dict[str, LogisticModel] = {}
176
- latency: dict[str, LogisticModel] = {}
177
- throughput: dict[str, LogisticModel] = {}
178
- for model_label, group in subsets_by_label.items():
179
- exclude = set(data_sources[model_label].fit_exclude_batch_sizes)
180
- by_batch: dict[int, list[tuple[float, float, float]]] = {}
181
- for r in group:
182
- if r.max_num_seqs in exclude:
183
- continue
184
- by_batch.setdefault(r.max_num_seqs, []).append(
185
- (r.avg_power_watts, r.mean_itl_ms / 1000.0, r.output_throughput_tokens_per_sec)
186
- )
187
- all_by_batch[model_label] = by_batch
188
-
189
- batches = sorted(by_batch.keys())
190
- if not batches:
191
- continue
192
-
193
- x = np.log2(np.array(batches, dtype=float).clip(min=1))
194
- for _metric_name, idx, target in [
195
- ("power", 0, power),
196
- ("latency", 1, latency),
197
- ("throughput", 2, throughput),
198
- ]:
199
- y = np.array([float(np.median([t[idx] for t in by_batch[b]])) for b in batches])
200
- fit = LogisticModel.fit(x, y)
201
- target[model_label] = fit
202
-
203
- if not power and not latency and not throughput:
204
- raise ValueError("No logistic fit rows produced")
205
- store = cls(power=power, latency=latency, throughput=throughput)
206
- store._by_batch = all_by_batch
207
- return store
208
-
209
- def save(self, csv_path: Path, *, plot: bool = False) -> None:
210
- """Save logistic fits to a CSV.
211
-
212
- Args:
213
- csv_path: Output CSV path.
214
- plot: If `True`, also write a logistic fits plot to the
215
- same directory.
216
- """
217
- csv_path = Path(csv_path)
218
- csv_path.parent.mkdir(parents=True, exist_ok=True)
219
- rows: list[dict[str, Any]] = []
220
- for metric_name, fits in [("power", self._power), ("latency", self._latency), ("throughput", self._throughput)]:
221
- for label in sorted(fits):
222
- model = fits[label]
223
- rows.append(
224
- {
225
- self.COL_MODEL_LABEL: label,
226
- self.COL_METRIC: metric_name,
227
- "L": model.L,
228
- "x0": model.x0,
229
- "k": model.k,
230
- "b0": model.b0,
231
- }
232
- )
233
- pd.DataFrame(rows).to_csv(csv_path, index=False)
234
-
235
- by_batch = getattr(self, "_by_batch", None)
236
- if plot and by_batch is not None:
237
- model_labels = sorted(self._power.keys())
238
- _plot_logistic_fits(
239
- by_batch,
240
- self._power,
241
- self._latency,
242
- self._throughput,
243
- model_labels,
244
- csv_path.parent,
245
- )
246
-
247
- @classmethod
248
- def load(cls, csv_path: Path | str) -> LogisticModelStore:
249
- """Load power, latency, and throughput fits from a merged CSV.
250
-
251
- Expected columns: `model_label`, `metric`, plus the logistic
252
- model parameter columns (`L`, `x0`, `k`, `b0`).
253
-
254
- The `metric` column must contain `power`, `latency`, or
255
- `throughput` (case-insensitive).
256
-
257
- Args:
258
- csv_path: Path to the logistic fits CSV.
259
- """
260
- csv_path = Path(csv_path)
261
- df = pd.read_csv(csv_path)
262
-
263
- required_cols = [cls.COL_MODEL_LABEL, cls.COL_METRIC]
264
- missing = [c for c in required_cols if c not in df.columns]
265
- if missing:
266
- raise ValueError(f"{csv_path} missing columns: {missing}. Got: {list(df.columns)}")
267
-
268
- power: dict[str, LogisticModel] = {}
269
- latency: dict[str, LogisticModel] = {}
270
- throughput: dict[str, LogisticModel] = {}
271
- targets = {"power": power, "latency": latency, "throughput": throughput}
272
- for row in df.to_dict(orient="records"):
273
- metric = str(row[cls.COL_METRIC]).strip().lower()
274
- if metric in targets:
275
- targets[metric][str(row[cls.COL_MODEL_LABEL])] = LogisticModel.from_dict(row)
276
-
277
- if not power and not latency and not throughput:
278
- raise ValueError(f"No logistic model rows loaded from {csv_path}")
279
- return cls(power=power, latency=latency, throughput=throughput)
280
-
281
- @classmethod
282
- def ensure(
283
- cls,
284
- csv_path: Path,
285
- models: tuple[InferenceModelSpec, ...] | None = None,
286
- data_sources: dict[str, Any] | None = None,
287
- *,
288
- mlenergy_data_dir: Path | None = None,
289
- plot: bool = False,
290
- ) -> LogisticModelStore:
291
- """Load from `csv_path`, generating first if needed.
292
-
293
- Args:
294
- csv_path: Path to the logistic fits CSV.
295
- models: Model specifications. Required when no cached file exists.
296
- data_sources: Per-model `MLEnergySource` instances, keyed by
297
- `model_label`. Required when no cached file exists.
298
- mlenergy_data_dir: Path to compiled mlenergy-data directory.
299
- plot: If `True`, generate a logistic fits plot on generation.
300
- """
301
- csv_path = Path(csv_path)
302
- if not csv_path.exists():
303
- if models is None or data_sources is None:
304
- raise ValueError("models and data_sources required for LogisticModelStore generation (no cached data)")
305
- logger.info("Generating logistic fits to %s ...", csv_path)
306
- cls.generate(models, data_sources, mlenergy_data_dir=mlenergy_data_dir).save(csv_path, plot=plot)
307
- return cls.load(csv_path)
308
-
309
-
310
- class VoltageDualVariables:
311
- """Full-network duals for voltage box constraints.
312
-
313
- Maintains per-bus dual variables for under- and overvoltage and updates
314
- them via projected gradient ascent:
315
-
316
- dual_undervoltage <- [dual_undervoltage + ρ_v * (v_min - v̂)]+
317
- dual_overvoltage <- [dual_overvoltage + ρ_v * (v̂ - v_max)]+
318
-
319
- Args:
320
- n_bus_phases: Number of bus-phase pairs in the voltage vector (3M).
321
- config: OFO configuration (voltage bounds and dual step size).
322
- """
323
-
324
- def __init__(self, n_bus_phases: int, config: OFOConfig) -> None:
325
- self.config = config
326
- self.dual_undervoltage = np.zeros(int(n_bus_phases), dtype=float) # λ in G2G paper Eq. 5
327
- self.dual_overvoltage = np.zeros(int(n_bus_phases), dtype=float) # λ̄ in G2G paper Eq. 6
328
-
329
- def update(self, observed_voltages: np.ndarray) -> None:
330
- """Update duals given observed voltage vector.
331
-
332
- Args:
333
- observed_voltages: Observed voltage magnitudes (pu), shape
334
- `(n_bus_phases,)`.
335
-
336
- Raises:
337
- ValueError: If `observed_voltages` length does not match the dual
338
- dimension.
339
- """
340
- observed_voltages = np.asarray(observed_voltages, float).reshape(-1)
341
- if observed_voltages.shape[0] != self.dual_undervoltage.shape[0]:
342
- raise ValueError(
343
- f"observed_voltages has len {observed_voltages.shape[0]} "
344
- f"but duals have len {self.dual_undervoltage.shape[0]}"
345
- )
346
- vmin = float(self.config.v_min)
347
- vmax = float(self.config.v_max)
348
- rho = float(self.config.voltage_dual_step_size)
349
- self.dual_undervoltage = np.maximum(self.dual_undervoltage + rho * (vmin - observed_voltages), 0.0)
350
- self.dual_overvoltage = np.maximum(self.dual_overvoltage + rho * (observed_voltages - vmax), 0.0)
351
-
352
- def dual_difference(self) -> np.ndarray:
353
- """Return the voltage dual difference (η = λ̄ − λ, Appendix B)."""
354
- return self.dual_overvoltage - self.dual_undervoltage
355
-
356
-
357
class PrimalBatchOptimizer:
    """Gradient-descent batch-size optimizer working on `x_i = log2(batch_i)`.

    One continuous state value is kept per model. Each call to `step`
    moves that value along the negative Lagrangian gradient (built from
    voltage duals, latency duals, and the fitted power/latency/throughput
    curves), clamps it to the feasible range, and rounds it to the nearest
    feasible batch size.

    Args:
        models: Model specifications for each served model.
        feasible_batch_sizes: Allowed batch sizes (union across all models).
        power_fits: Per-model logistic fit for power vs log2(batch_size).
        latency_fits: Per-model logistic fit for latency vs log2(batch_size).
        throughput_fits: Per-model logistic fit for throughput vs
            log2(batch_size).
        config: OFO configuration (step size, throughput/switch weights,
            voltage gradient scale).

    Raises:
        ValueError: If `feasible_batch_sizes` is empty.
    """

    def __init__(
        self,
        *,
        models: list[InferenceModelSpec],
        feasible_batch_sizes: list[int],
        power_fits: dict[str, LogisticModel],
        latency_fits: dict[str, LogisticModel],
        throughput_fits: dict[str, LogisticModel],
        config: OFOConfig,
    ) -> None:
        self.models = list(models)

        # De-duplicate and order the batch sizes so bisect can be used later.
        sizes = sorted({int(b) for b in feasible_batch_sizes})
        self.feasible_batch_sizes = sizes
        if not sizes:
            raise ValueError("feasible_batch_sizes cannot be empty.")
        smallest, largest = sizes[0], sizes[-1]

        self.power_fits = power_fits
        self.latency_fits = latency_fits
        self.throughput_fits = throughput_fits
        self.config = config

        self.log_batch_size_min = math.log2(smallest)
        self.log_batch_size_max = math.log2(largest)

        # Every model starts at the largest feasible batch size.
        self.log_batch_size_by_model: dict[str, float] = dict.fromkeys(
            (spec.model_label for spec in self.models),
            float(self.log_batch_size_max),
        )
        self.prev_log_batch_size_by_model: dict[str, float] = dict(self.log_batch_size_by_model)

        # Per-model throughput normalization: r_i(x_max) for a single replica.
        self.throughput_max_by_model: dict[str, float] = {}
        for spec in self.models:
            name = spec.model_label
            try:
                peak = float(self.throughput_fits[name].eval(largest))
            except Exception:
                peak = float("nan")
            # Unusable peaks (missing, non-finite, non-positive) disable the
            # normalization instead of dividing by a bad value later.
            usable = np.isfinite(peak) and peak > 0.0
            self.throughput_max_by_model[name] = peak if usable else 1.0

    def _clamp_log_batch_size(self, log_batch_size: float) -> float:
        """Clip a log2 batch size into `[log_batch_size_min, log_batch_size_max]`."""
        floored = max(float(log_batch_size), self.log_batch_size_min)
        return float(min(floored, self.log_batch_size_max))

    def _discretize_batch(self, log_batch_size: float) -> int:
        """Return the feasible batch size closest to `2 ** log_batch_size`.

        On an exact tie the smaller neighbour wins.
        """
        target = 2.0 ** float(log_batch_size)
        sizes = self.feasible_batch_sizes
        pos = bisect.bisect_left(sizes, target)
        # The slice holds the neighbour below (if any) then the one at/above.
        neighbours = sizes[max(pos - 1, 0) : pos + 1]
        return int(min(neighbours, key=lambda size: abs(size - target)))

    def init_from_batches(self, batch_init: dict[str, int]) -> None:
        """Seed current and previous log-batch-size state from discrete batch sizes.

        Models absent from `batch_init` start at the largest feasible batch
        size. Values are floored at 1 before taking log2 and then clamped.
        """
        fallback = max(self.feasible_batch_sizes)
        for spec in self.models:
            name = spec.model_label
            batch = int(batch_init.get(name, fallback))
            seeded = float(self._clamp_log_batch_size(math.log2(max(batch, 1))))
            self.log_batch_size_by_model[name] = seeded
            self.prev_log_batch_size_by_model[name] = seeded

    def step(
        self,
        *,
        voltage_dual_diff: np.ndarray,
        sensitivity_matrix: np.ndarray,
        phase_share_by_model: dict[str, np.ndarray],
        latency_dual_by_model: dict[str, float] | None = None,
        replica_count_by_model: dict[str, float] | None = None,
    ) -> dict[str, int]:
        """Primal gradient descent step.

        Args:
            voltage_dual_diff: Voltage dual difference vector
                (η = λ̄ − λ), shape `(n_bus_phases,)`.
            sensitivity_matrix: Voltage sensitivity matrix (H = dv/dp),
                shape `(n_bus_phases, 3)`.
            phase_share_by_model: Per-model normalized phase share vectors,
                shape `(3,)` each. Missing or non-positive-sum entries fall
                back to an even three-way split.
            latency_dual_by_model: Per-model latency dual variables (μ_i).
                Missing, negative, or non-finite values are treated as 0.
            replica_count_by_model: Per-model active replica counts (w_i).
                Missing, negative, or non-finite values are treated as 0.

        Returns:
            Next batch sizes per model.
        """
        eta = np.asarray(voltage_dual_diff, float).reshape(-1)
        sensitivity = np.asarray(sensitivity_matrix, float)
        mu_by_model = dict(latency_dual_by_model) if latency_dual_by_model is not None else {}
        replicas_by_model = dict(replica_count_by_model) if replica_count_by_model is not None else {}

        rho_x = float(self.config.primal_step_size)  # ρ_x
        throughput_weight = float(self.config.w_throughput)
        switch_weight = float(self.config.w_switch)
        voltage_scale = float(self.config.voltage_gradient_scale)

        even_split = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
        next_batches: dict[str, int] = {}

        for spec in self.models:
            name = spec.model_label
            x_now = float(self.log_batch_size_by_model[name])
            x_prev = float(self.prev_log_batch_size_by_model.get(name, x_now))

            replicas = float(replicas_by_model.get(name, 0.0))  # w_i
            if not (np.isfinite(replicas) and replicas >= 0.0):
                replicas = 0.0

            # e_i (phase-allocation weight, p.7), renormalized to sum to one.
            share = np.asarray(phase_share_by_model.get(name, even_split), float).reshape(3)
            share_total = float(np.sum(share))
            if np.isfinite(share_total) and share_total > 0.0:
                share = share / share_total
            else:
                share = even_split

            # η^T (H @ e_i)
            voltage_term = float(eta @ (sensitivity @ share))

            power_slope = float(self.power_fits[name].deriv_wrt_x(x_now))
            latency_slope = float(self.latency_fits[name].deriv_wrt_x(x_now))
            throughput_slope = float(self.throughput_fits[name].deriv_wrt_x(x_now))

            # Power slope is divided by 1000 (W -> kW per the original naming).
            power_slope_kw = power_slope / 1000.0

            peak = float(self.throughput_max_by_model.get(name, 1.0))
            if not (np.isfinite(peak) and peak > 0.0):
                peak = 1.0
            throughput_slope_norm = throughput_slope / peak

            # Power and throughput scale with the replica count; latency does not.
            total_power_slope = replicas * power_slope_kw
            total_throughput_slope = replicas * throughput_slope_norm

            mu = float(mu_by_model.get(name, 0.0))  # μ_i
            if not (np.isfinite(mu) and mu >= 0.0):
                mu = 0.0

            # Gradient of the Lagrangian w.r.t. x_i = log2(batch_i).
            # G2G paper Eq. 18: nabla_x L = -dR/dx            (throughput)
            #                             + 2*gamma*(x - x_prev) (switching)
            #                             + eta^T H e_i dP/dx  (voltage dual)
            #                             + mu_i * dL/dx       (latency dual)
            # Implementation extensions: wT scaling on throughput,
            # k_v scaling on voltage term.
            # Terms are accumulated left to right in this exact order.
            gradient = (
                0.0
                - throughput_weight * total_throughput_slope
                + voltage_scale * voltage_term * total_power_slope
                + mu * latency_slope
                + switch_weight * (x_now - x_prev)
            )

            x_next = self._clamp_log_batch_size(x_now - rho_x * gradient)
            self.prev_log_batch_size_by_model[name] = x_now
            self.log_batch_size_by_model[name] = x_next
            next_batches[name] = self._discretize_batch(x_next)

        return next_batches
535
-
536
-
537
class OFOBatchSizeController(Controller[LLMBatchSizeControlledDatacenter[LLMDatacenterState], OpenDSSGrid]):
    """Online Feedback Optimization controller for batch-size regulation.

    On every control step the controller reads grid voltages and the
    datacenter state, updates the voltage and latency duals, runs the
    primal batch-size optimizer, and returns the new batch sizes.
    Latency dual updates use [`dc_state.observed_itl_s_by_model`
    ][openg2g.datacenter.base.LLMDatacenterState.observed_itl_s_by_model].

    Args:
        inference_models: Model specifications served in the datacenter.
        models: Per-model logistic models for power, latency, and
            throughput used in gradient computation.
        config: Unified OFO tuning parameters. A default `OFOConfig()` is
            used when omitted.
        dt_s: Control interval (seconds).

    Raises:
        ValueError: If `inference_models` is empty, contains duplicate
            labels, or `models` lacks a fit for any model/metric pair.
    """

    def __init__(
        self,
        inference_models: tuple[InferenceModelSpec, ...],
        *,
        models: LogisticModelStore,
        config: OFOConfig | None = None,
        dt_s: Fraction = Fraction(1),
    ) -> None:
        cfg = OFOConfig() if config is None else config

        if not inference_models:
            raise ValueError("inference_models must not be empty.")
        labels = [spec.model_label for spec in inference_models]
        if len(set(labels)) != len(labels):
            raise ValueError(f"Duplicate model labels: {labels}")

        specs = list(inference_models)

        # Fail early if any model lacks one of the three fitted curves.
        required_fits = (
            ("power", models.power),
            ("latency", models.latency),
            ("throughput", models.throughput),
        )
        for spec in specs:
            name = spec.model_label
            for metric_name, lookup in required_fits:
                try:
                    lookup(name)
                except KeyError:
                    raise ValueError(f"LogisticModelStore missing {metric_name} model for {name!r}.") from None

        self._dt_s = dt_s
        self._models = specs
        self._config = cfg
        self._itl_deadline_by_model = {spec.model_label: spec.itl_deadline_s for spec in specs}

        # Voltage duals are created lazily in step(), once the grid size is known.
        self._voltage_dual: VoltageDualVariables | None = None
        self._latency_dual_by_model: dict[str, float] = {spec.model_label: 0.0 for spec in specs}

        # The optimizer discretizes onto the union of all models' batch sizes.
        feasible_batch_sizes = sorted(set().union(*(spec.feasible_batch_sizes for spec in specs)))

        self._optimizer = PrimalBatchOptimizer(
            models=specs,
            feasible_batch_sizes=feasible_batch_sizes,
            power_fits=models.power_fits,
            latency_fits=models.latency_fits,
            throughput_fits=models.throughput_fits,
            config=cfg,
        )
        self._optimizer.init_from_batches({spec.model_label: spec.initial_batch_size for spec in specs})

        self._sensitivity_matrix: np.ndarray | None = None
        self._control_step_count: int = 0

        logger.info(
            "OFOBatchSizeController: %d models, dt=%s s, feasible_batches=%s",
            len(specs),
            dt_s,
            feasible_batch_sizes,
        )

    def reset(self) -> None:
        """Drop duals and the cached sensitivity, and re-seed the optimizer from initial batch sizes."""
        self._voltage_dual = None
        self._latency_dual_by_model = dict.fromkeys((spec.model_label for spec in self._models), 0.0)
        self._optimizer.init_from_batches({spec.model_label: spec.initial_batch_size for spec in self._models})
        self._sensitivity_matrix = None
        self._control_step_count = 0

    @property
    def dt_s(self) -> Fraction:
        """Control interval (seconds)."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        datacenter: LLMBatchSizeControlledDatacenter[LLMDatacenterState],
        grid: OpenDSSGrid,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Run one OFO iteration and return a single `SetBatchSize` command.

        Raises:
            RuntimeError: If the datacenter state lacks a replica count or
                an observed inter-token latency for any served model.
        """
        if self._voltage_dual is None:
            self._voltage_dual = VoltageDualVariables(len(grid.v_index), self._config)

        # 1. Re-estimate sensitivity if needed: always on first use, then
        #    every `sensitivity_update_interval` steps when that is positive.
        needs_estimate = self._sensitivity_matrix is None
        if not needs_estimate:
            interval = self._config.sensitivity_update_interval
            needs_estimate = interval > 0 and self._control_step_count % interval == 0
        if needs_estimate:
            self._sensitivity_matrix, _ = grid.estimate_sensitivity(self._config.sensitivity_perturbation_kw)

        # 2. Update voltage duals from grid state
        self._voltage_dual.update(grid.voltages_vector())
        eta = self._voltage_dual.dual_difference()  # η = λ̄ − λ

        # 3. Read observed latency from datacenter and update latency duals
        state = datacenter.state
        absent = [spec.model_label for spec in self._models if spec.model_label not in state.active_replicas_by_model]
        if absent:
            miss = ", ".join(sorted(absent))
            raise RuntimeError(
                f"OFOBatchSizeController requires active_replicas_by_model for all models. Missing: {miss}."
            )
        absent = [spec.model_label for spec in self._models if spec.model_label not in state.observed_itl_s_by_model]
        if absent:
            miss = ", ".join(sorted(absent))
            raise RuntimeError(
                f"OFOBatchSizeController requires observed_itl_s_by_model for all models. Missing: {miss}."
            )

        for spec in self._models:
            name = spec.model_label
            replicas = max(int(state.active_replicas_by_model[name]), 0)
            itl = float(state.observed_itl_s_by_model[name])
            if replicas <= 0:
                logger.debug("Model %s has 0 replicas, skipping latency dual update", name)
                itl = float("nan")

            deadline = float(self._itl_deadline_by_model[name])
            candidate = self._latency_dual_by_model[name]
            if np.isfinite(itl):
                # Projected ascent on the latency constraint (itl - deadline).
                candidate = candidate + self._config.latency_dual_step_size * (itl - deadline)
            # Without a usable observation the dual is only re-projected onto >= 0.
            self._latency_dual_by_model[name] = max(candidate, 0.0)

        # 4. Compute replica counts
        replica_count_by_model: dict[str, float] = {
            spec.model_label: float(state.active_replicas_by_model[spec.model_label]) for spec in self._models
        }

        # 5. Primal update -> next batch sizes
        batch_next = self._optimizer.step(
            voltage_dual_diff=eta,
            sensitivity_matrix=self._sensitivity_matrix,
            phase_share_by_model=datacenter.phase_share_by_model,
            latency_dual_by_model=self._latency_dual_by_model,
            replica_count_by_model=replica_count_by_model,
        )

        self._control_step_count += 1
        logger.info(
            "OFO step %d (t=%.1f s): batch=%s",
            self._control_step_count,
            clock.time_s,
            batch_next,
        )
        events.emit(
            "controller.ofo.step",
            {
                "batch_size_by_model": batch_next,
                "latency_dual_by_model": dict(self._latency_dual_by_model),
            },
        )
        return [SetBatchSize(batch_size_by_model=batch_next)]
716
-
717
-
718
def _plot_logistic_fits(
    by_batch: dict[str, dict[int, list[tuple[float, float, float]]]],
    power: dict[str, LogisticModel],
    latency_fits: dict[str, LogisticModel],
    throughput_fits: dict[str, LogisticModel],
    model_labels: list[str],
    out_dir: Path,
) -> None:
    """Plot 3x1 stacked logistic fits: power, latency, throughput.

    Scatter dots for measured medians, smooth fitted curves from
    LogisticModel parameters. Saves to `out_dir / "logistic_fits.png"`,
    creating the directory if needed.

    Args:
        by_batch: Per model label, a mapping from batch size to measured
            samples; tuple positions 0/1/2 are read as power, latency and
            throughput respectively.
        power: Per-model fit used for the power panel.
        latency_fits: Per-model fit used for the latency panel.
        throughput_fits: Per-model fit used for the throughput panel.
        model_labels: Labels to plot, in legend order.
        out_dir: Directory that receives `logistic_fits.png`.
    """
    import matplotlib.pyplot as plt

    metric_specs: list[tuple[str, int, dict[str, LogisticModel], str, str]] = [
        ("power", 0, power, "W", "(a) Average GPU power consumption vs batch size"),
        ("latency", 1, latency_fits, "s/token", "(b) Average inter-token latency vs batch size"),
        ("throughput", 2, throughput_fits, "tokens/s", "(c) Average token throughput vs batch size"),
    ]

    # The x-coordinates depend only on the measured batch sizes, not on the
    # metric, so they are computed once instead of once per panel.
    measured: dict[str, tuple[list[int], np.ndarray]] = {}
    for label in model_labels:
        batches = sorted(by_batch.get(label, {}))
        if batches:
            # Batch sizes below 1 are clipped so log2 stays defined.
            measured[label] = (batches, np.log2(np.array(batches, dtype=float).clip(min=1)))

    xs: np.ndarray | None = None
    if measured:
        x_low = min(float(np.min(x)) for _, x in measured.values())
        x_high = max(float(np.max(x)) for _, x in measured.values())
        xs = np.linspace(x_low, x_high, 400)

    fig, axes = plt.subplots(3, 1, figsize=(6.45, 5.2), dpi=300, sharex=True)
    try:
        for ax_idx, (ax, (_metric_name, val_idx, fits, ylabel, title)) in enumerate(
            zip(axes, metric_specs, strict=True)
        ):
            if xs is None:
                # No measurements at all: leave an empty, labelled panel.
                ax.set_title(title, fontsize=12, loc="center")
                ax.set_ylabel(ylabel, fontsize=10)
                ax.grid(True, alpha=0.25)
                continue

            for label in model_labels:
                if label not in measured or label not in fits:
                    continue
                batches, x = measured[label]
                model_by_batch = by_batch[label]
                y = np.array([float(np.median([t[val_idx] for t in model_by_batch[b]])) for b in batches])

                fit = fits[label]
                ys_fit = np.array([fit.eval_x(float(xi)) for xi in xs])

                # Dots reuse the curve colour so each model reads as one series.
                (line,) = ax.plot(xs, ys_fit, lw=1.8, label=label, zorder=2)
                ax.scatter(x, y, s=16.0, color=line.get_color(), zorder=3)

            ax.set_title(title, fontsize=12, loc="center")
            ax.set_ylabel(ylabel, fontsize=10)
            ax.grid(True, alpha=0.25)
            ax.tick_params(axis="both", labelsize=10)

            # Only the bottom panel carries the legend.
            if ax_idx == 2:
                ax.legend(frameon=True, fontsize=9, loc="best")

        axes[-1].set_xlabel(r"$\log_2(\mathrm{batch\ size})$", fontsize=10)
        fig.tight_layout(pad=0.35, h_pad=0.6)

        save_path = out_dir / "logistic_fits.png"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0.02)
    finally:
        # pyplot keeps figures alive in a global registry until closed, so
        # release it even when evaluation or saving fails.
        plt.close(fig)
    logger.info("Saved logistic fits plot to %s", save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/controller/tap_schedule.py DELETED
@@ -1,70 +0,0 @@
1
- """Tap schedule controller: applies pre-defined regulator tap changes at specified times."""
2
-
3
- from __future__ import annotations
4
-
5
- from fractions import Fraction
6
-
7
- from openg2g.clock import SimulationClock
8
- from openg2g.controller.base import Controller
9
- from openg2g.datacenter.base import DatacenterBackend
10
- from openg2g.datacenter.command import DatacenterCommand
11
- from openg2g.events import EventEmitter
12
- from openg2g.grid.base import GridBackend
13
- from openg2g.grid.command import GridCommand, SetTaps
14
- from openg2g.grid.config import TapPosition, TapSchedule
15
-
16
-
17
class TapScheduleController(Controller[DatacenterBackend, GridBackend]):
    """Applies pre-defined tap changes at scheduled times.

    Schedule entries are consumed in order; an entry fires once its time
    is at or before the current clock time (with a 1e-12 s tolerance).

    Args:
        schedule: Tap schedule built via
            [`TapPosition(...).at(t=...) | ...`][openg2g.grid.config.TapSchedule].
        dt_s: How often the controller checks the schedule (seconds).
    """

    def __init__(self, *, schedule: TapSchedule, dt_s: Fraction = Fraction(1)) -> None:
        self._dt_s = dt_s
        self._entries = list(schedule)
        # Index of the next schedule entry that has not fired yet.
        self._idx = 0

    def reset(self) -> None:
        """Rewind to the first schedule entry."""
        self._idx = 0

    @property
    def dt_s(self) -> Fraction:
        """Schedule polling interval (seconds)."""
        return self._dt_s

    def step(
        self,
        clock: SimulationClock,
        datacenter: DatacenterBackend,
        grid: GridBackend,
        events: EventEmitter,
    ) -> list[DatacenterCommand | GridCommand]:
        """Fire every due schedule entry and return at most one merged `SetTaps`.

        When several entries are due in the same step, later entries
        override earlier ones per phase; phases never set stay `None`.

        Returns:
            A one-element list with the merged `SetTaps` command, or an
            empty list if no due entry set any phase.
        """
        now = clock.time_s
        latest: dict[str, float | None] = {"a": None, "b": None, "c": None}

        while self._idx < len(self._entries):
            fire_at, position = self._entries[self._idx]
            # Small tolerance absorbs floating-point drift in the clock time.
            if not (float(fire_at) <= now + 1e-12):
                break
            for phase in ("a", "b", "c"):
                requested = getattr(position, phase)
                if requested is not None:
                    latest[phase] = requested
            self._idx += 1

        # A phase can only be set by a fired entry, so "nothing set" also
        # covers "nothing fired".
        if all(value is None for value in latest.values()):
            return []

        tap = TapPosition(a=latest["a"], b=latest["b"], c=latest["c"])
        events.emit("controller.tap_schedule.fired", {"tap_position": tap})
        return [SetTaps(tap_position=tap)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/coordinator.py DELETED
@@ -1,269 +0,0 @@
1
- """Central coordinator: multi-rate simulation loop."""
2
-
3
- from __future__ import annotations
4
-
5
- import logging
6
- import warnings
7
- from collections.abc import Sequence
8
- from dataclasses import dataclass, field
9
- from fractions import Fraction
10
- from typing import Any, Generic
11
-
12
- from openg2g.clock import SimulationClock
13
- from openg2g.common import ThreePhase
14
- from openg2g.controller.base import Controller
15
- from openg2g.datacenter.base import DatacenterBackend, DCStateT
16
- from openg2g.datacenter.command import DatacenterCommand
17
- from openg2g.events import EventEmitter, SimEvent
18
- from openg2g.grid.base import GridBackend, GridStateT, PhaseVoltages
19
- from openg2g.grid.command import GridCommand
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
@dataclass
class SimulationLog(Generic[DCStateT, GridStateT]):
    """Accumulated simulation data from a coordinator run.

    Generic over the datacenter and grid state types. When constructed
    via [`Coordinator.run`][..Coordinator.run], the type parameters are
    inferred from the backends, giving typed access to backend-specific
    state fields.

    Attributes:
        dc_states: Every datacenter state produced by the datacenter.
        grid_states: Every grid state produced by the grid.
        commands: All commands emitted by controllers.
        time_s: Simulation time at each grid step (seconds).
        voltage_a_pu: DC-bus voltage phase A at each grid step (pu).
        voltage_b_pu: DC-bus voltage phase B at each grid step (pu).
        voltage_c_pu: DC-bus voltage phase C at each grid step (pu).
        events: Clock-stamped simulation events from all components.
    """

    dc_states: list[DCStateT] = field(default_factory=list)
    grid_states: list[GridStateT] = field(default_factory=list)
    commands: list[DatacenterCommand | GridCommand] = field(default_factory=list)

    time_s: list[float] = field(default_factory=list)
    voltage_a_pu: list[float] = field(default_factory=list)
    voltage_b_pu: list[float] = field(default_factory=list)
    voltage_c_pu: list[float] = field(default_factory=list)

    events: list[SimEvent] = field(default_factory=list)

    def record_datacenter(self, state: DCStateT) -> None:
        """Store one datacenter state snapshot."""
        self.dc_states.append(state)

    def record_grid(self, state: GridStateT, *, dc_bus: str) -> None:
        """Store one grid state snapshot together with its DC-bus voltages.

        If `dc_bus` is not present in `state.voltages`, NaN is recorded for
        all three phases so the per-phase series stay aligned with `time_s`.
        """
        self.grid_states.append(state)
        self.time_s.append(state.time_s)

        if dc_bus in state.voltages:
            bus_voltage = state.voltages[dc_bus]
        else:
            bus_voltage = PhaseVoltages(a=float("nan"), b=float("nan"), c=float("nan"))

        self.voltage_a_pu.append(bus_voltage.a)
        self.voltage_b_pu.append(bus_voltage.b)
        self.voltage_c_pu.append(bus_voltage.c)

    def record_commands(self, commands: list[DatacenterCommand | GridCommand]) -> None:
        """Store the control commands issued during a tick."""
        self.commands.extend(commands)

    def emit(self, event: SimEvent) -> None:
        """Event sink entrypoint for component-originated events."""
        self.events.append(event)
80
-
81
-
82
- def _gcd_fraction(a: Fraction, b: Fraction) -> Fraction:
83
- """GCD of two positive Fractions using Euclidean algorithm."""
84
- a, b = abs(a), abs(b)
85
- while b:
86
- a, b = b, a % b
87
- return a
88
-
89
-
90
class Coordinator(Generic[DCStateT, GridStateT]):
    """Multi-rate simulation coordinator.

    Orchestrates datacenter, grid, and controller components at their
    respective rates. The base tick is the GCD of all component
    periods.

    Generic over datacenter and grid state types. The type parameters
    are inferred from the backends and propagated to
    [`SimulationLog`][..SimulationLog].

    Args:
        datacenter: Datacenter backend (offline or online).
        grid: Grid simulator backend.
        controllers: List of controllers, applied in order each tick.
        total_duration_s: Total simulation duration (integer seconds).
        dc_bus: Bus name for DC voltage logging.
        live: If True, synchronize with wall-clock time.
    """

    def __init__(
        self,
        datacenter: DatacenterBackend[DCStateT],
        grid: GridBackend[GridStateT],
        controllers: Sequence[Controller[Any, Any]],
        total_duration_s: int,
        dc_bus: str,
        live: bool = False,
    ) -> None:
        self.datacenter = datacenter
        self.grid = grid
        # Materialize once; everything below iterates this list so a
        # one-shot iterable is not silently exhausted.
        self.controllers = list(controllers)
        self.total_duration_s = int(total_duration_s)
        self.dc_bus = str(dc_bus)

        # Compute tick as GCD of all component periods
        periods = [datacenter.dt_s, grid.dt_s] + [c.dt_s for c in self.controllers]
        tick = periods[0]
        for p in periods[1:]:
            tick = _gcd_fraction(tick, p)
        logger.info("Coordinator will run with tick %f s", float(tick))

        # Warn about potentially problematic dt configurations
        if grid.dt_s < datacenter.dt_s:
            warnings.warn(
                f"dt_grid ({grid.dt_s}) < dt_dc ({datacenter.dt_s}): "
                f"grid steps between DC steps will reuse the most recent DC power.",
                stacklevel=2,
            )
        for ctrl in self.controllers:
            if ctrl.dt_s < grid.dt_s:
                warnings.warn(
                    f"Controller {ctrl.__class__.__name__} dt_s ({ctrl.dt_s}) "
                    f"< dt_grid ({grid.dt_s}): controller may read stale voltages.",
                    stacklevel=2,
                )
        n_ticks_estimate = Fraction(self.total_duration_s) / tick
        if n_ticks_estimate > 10_000_000:
            warnings.warn(
                f"Simulation will run {int(n_ticks_estimate)} ticks. This may be slow. Consider coarser time steps.",
                stacklevel=2,
            )

        self.clock = SimulationClock(tick_s=tick, live=live)

    def reset(self) -> None:
        """Reset coordinator and all sub-components for a fresh run."""
        self.clock.reset()
        self.datacenter.do_reset()
        self.grid.do_reset()
        for ctrl in self.controllers:
            ctrl.reset()

    def start(self) -> None:
        """Acquire resources on all sub-components."""
        self.datacenter.start()
        self.grid.start()
        for ctrl in self.controllers:
            ctrl.start()

    def stop(self) -> None:
        """Release resources on all sub-components (LIFO order)."""
        for ctrl in reversed(self.controllers):
            ctrl.stop()
        self.grid.stop()
        self.datacenter.stop()

    def _validate_controller_compatibility(self) -> None:
        """Check each controller against the configured datacenter and grid.

        Raises:
            TypeError: If the datacenter or grid is not an instance of the
                types a controller declares as compatible.
        """
        for ctrl in self.controllers:
            sig = ctrl.__class__.compatibility_signature()

            dc_types = ctrl.compatible_datacenter_types()
            try:
                dc_ok = isinstance(self.datacenter, dc_types)
            except TypeError:
                # The declared types cannot be used with isinstance(); the
                # remaining checks for this controller are skipped as well.
                continue
            if not dc_ok:
                expected = " | ".join(t.__name__ for t in dc_types)
                got = type(self.datacenter).__name__
                raise TypeError(f"{ctrl.__class__.__name__} ({sig}) requires datacenter type {expected}, got {got}.")

            grid_types = ctrl.compatible_grid_types()
            try:
                grid_ok = isinstance(self.grid, grid_types)
            except TypeError:
                continue
            if not grid_ok:
                expected = " | ".join(t.__name__ for t in grid_types)
                got = type(self.grid).__name__
                raise TypeError(f"{ctrl.__class__.__name__} ({sig}) requires grid type {expected}, got {got}.")

    def run(self) -> SimulationLog[DCStateT, GridStateT]:
        """Run the full simulation and return the log.

        Components are reset and started, stepped for every tick in which
        they are due, and always stopped afterwards.

        Raises:
            TypeError: If a controller is incompatible with the backends.
            ValueError: If `total_duration_s` is not an exact multiple of
                the tick, or a controller returns an unsupported command.
        """
        log: SimulationLog[DCStateT, GridStateT] = SimulationLog()
        dc_events = EventEmitter(self.clock, log, "datacenter")
        grid_events = EventEmitter(self.clock, log, "grid")
        controller_events = EventEmitter(self.clock, log, "controller")

        self._validate_controller_compatibility()

        self.reset()
        self.start()

        dc_buffer: list[ThreePhase] = []

        # Everything after start() runs under try/finally so that a failed
        # duration check also releases the components that were started.
        try:
            ratio = Fraction(self.total_duration_s) / self.clock.tick_s
            if ratio.denominator != 1:
                raise ValueError(
                    f"total_duration_s ({self.total_duration_s}) is not an exact multiple of tick_s ({self.clock.tick_s})"
                )
            n_ticks = int(ratio)

            logger.info(
                "Starting simulation: %d s, tick=%s s, %d ticks, dt_dc=%s s, dt_grid=%s s, %d controller(s)",
                self.total_duration_s,
                self.clock.tick_s,
                n_ticks,
                self.datacenter.dt_s,
                self.grid.dt_s,
                len(self.controllers),
            )

            for _ in range(n_ticks):
                # 1. Datacenter step (if due)
                if self.clock.is_due(self.datacenter.dt_s):
                    dc_state = self.datacenter.do_step(self.clock, dc_events)
                    dc_buffer.append(dc_state.power_w)
                    log.record_datacenter(dc_state)

                # 2. Grid step (if due). Pass full sub-trace since last grid step.
                if self.clock.is_due(self.grid.dt_s):
                    # A copy is passed because the buffer is cleared right after.
                    grid_state = self.grid.do_step(self.clock, list(dc_buffer), grid_events)
                    dc_buffer.clear()
                    log.record_grid(grid_state, dc_bus=self.dc_bus)

                # 3. Controllers (if due). In order, actions applied immediately.
                for ctrl in self.controllers:
                    if self.clock.is_due(ctrl.dt_s):
                        commands = ctrl.step(self.clock, self.datacenter, self.grid, controller_events)
                        for command in commands:
                            if isinstance(command, DatacenterCommand):
                                self.datacenter.apply_control(command, dc_events)
                            elif isinstance(command, GridCommand):
                                self.grid.apply_control(command, grid_events)
                            else:
                                raise ValueError(f"Unsupported command type: {type(command).__name__}")
                        log.record_commands(commands)

                self.clock.advance()
        finally:
            self.stop()

        logger.info(
            "Simulation complete: %d grid steps, %d DC steps, %d commands",
            len(log.grid_states),
            len(log.dc_states),
            len(log.commands),
        )
        return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Datacenter backends for openg2g."""
 
 
openg2g/datacenter/base.py DELETED
@@ -1,183 +0,0 @@
1
- """Abstract base class for datacenter backends and base state types."""
2
-
3
- from __future__ import annotations
4
-
5
- from abc import ABC, abstractmethod
6
- from dataclasses import dataclass, field
7
- from fractions import Fraction
8
- from typing import Generic, TypeVar, final
9
-
10
- import numpy as np
11
-
12
- from openg2g.clock import SimulationClock
13
- from openg2g.common import ThreePhase
14
- from openg2g.datacenter.command import DatacenterCommand
15
- from openg2g.events import EventEmitter
16
-
17
-
18
- @dataclass(frozen=True)
19
- class DatacenterState:
20
- """State emitted by a datacenter backend each timestep.
21
-
22
- Contains only universally applicable fields. LLM-inference-specific
23
- fields (batch sizes, replicas, latency) live on child classes like
24
- [`LLMDatacenterState`][..LLMDatacenterState].
25
-
26
- Attributes:
27
- time_s: Simulation time in seconds.
28
- power_w: Three-phase power in watts.
29
- """
30
-
31
- time_s: float
32
- power_w: ThreePhase
33
-
34
-
35
- @dataclass(frozen=True)
36
- class LLMDatacenterState(DatacenterState):
37
- """State from a datacenter serving LLM workloads.
38
-
39
- Extends [`DatacenterState`][..DatacenterState] with per-model batch
40
- size, replica count, and observed inter-token latency fields used
41
- by LLM controllers.
42
-
43
- Attributes:
44
- batch_size_by_model: Current batch size per model label.
45
- active_replicas_by_model: Number of active replicas per model.
46
- observed_itl_s_by_model: Observed average inter-token latency
47
- (seconds) per model. `NaN` if unavailable.
48
- """
49
-
50
- batch_size_by_model: dict[str, int] = field(default_factory=dict)
51
- active_replicas_by_model: dict[str, int] = field(default_factory=dict)
52
- observed_itl_s_by_model: dict[str, float] = field(default_factory=dict)
53
-
54
-
55
- DCStateT = TypeVar("DCStateT", bound=DatacenterState)
56
-
57
-
58
- class DatacenterBackend(Generic[DCStateT], ABC):
59
- """Interface for datacenter power simulation backends."""
60
-
61
- _INIT_SENTINEL = object()
62
-
63
- def __init__(self) -> None:
64
- self._state: DCStateT | None = None
65
- self._history: list[DCStateT] = []
66
- self._dc_base_init = DatacenterBackend._INIT_SENTINEL
67
-
68
- def _check_base_init(self) -> None:
69
- if getattr(self, "_dc_base_init", None) is not DatacenterBackend._INIT_SENTINEL:
70
- raise TypeError(f"{type(self).__name__}.__init__ must call super().__init__() ")
71
-
72
- @property
73
- @abstractmethod
74
- def dt_s(self) -> Fraction:
75
- """Native timestep as a Fraction (seconds)."""
76
-
77
- @final
78
- @property
79
- def state(self) -> DCStateT:
80
- """Latest emitted state.
81
-
82
- Raises:
83
- RuntimeError: If accessed before the first `step()` call.
84
- """
85
- self._check_base_init()
86
- if self._state is None:
87
- raise RuntimeError(f"{type(self).__name__}.state accessed before first step().")
88
- return self._state
89
-
90
- @final
91
- def history(self, n: int | None = None) -> list[DCStateT]:
92
- """Return emitted state history (all, or latest `n`)."""
93
- self._check_base_init()
94
- if n is None:
95
- return list(self._history)
96
- if n <= 0:
97
- return []
98
- return list(self._history[-int(n) :])
99
-
100
- @final
101
- def do_step(self, clock: SimulationClock, events: EventEmitter) -> DCStateT:
102
- """Call `step`, record the state, and return it.
103
-
104
- Called by the coordinator. Subclasses should not override this.
105
- """
106
- self._check_base_init()
107
- state = self.step(clock, events)
108
- self._state = state
109
- self._history.append(state)
110
- return state
111
-
112
- @abstractmethod
113
- def step(self, clock: SimulationClock, events: EventEmitter) -> DCStateT:
114
- """Advance one native timestep. Return state for this step."""
115
-
116
- @abstractmethod
117
- def apply_control(self, command: DatacenterCommand, events: EventEmitter) -> None:
118
- """Apply one command. Takes effect on next step() call."""
119
-
120
- @final
121
- def do_reset(self) -> None:
122
- """Clear history and call `reset`.
123
-
124
- Called by the coordinator. Subclasses should not override this.
125
- """
126
- self._check_base_init()
127
- self._state = None
128
- self._history.clear()
129
- self.reset()
130
-
131
- @abstractmethod
132
- def reset(self) -> None:
133
- """Reset simulation state to initial conditions.
134
-
135
- Called by the coordinator (via `do_reset`) before each
136
- [`start`][..start]. Must clear all simulation state: counters,
137
- RNG seeds, cached values. Configuration (dt_s, models,
138
- templates) is not affected. History is cleared automatically
139
- by `do_reset`.
140
-
141
- Abstract so every implementation explicitly enumerates its state.
142
- A forgotten field is a bug -- not clearing it silently corrupts
143
- the second run.
144
- """
145
-
146
- def start(self) -> None:
147
- """Acquire per-run resources (threads, solver circuits).
148
-
149
- Called after [`reset`][..reset], before the simulation loop.
150
- Override for backends that need resource acquisition (e.g.,
151
- [`OpenDSSGrid`][openg2g.grid.opendss.OpenDSSGrid] compiles its
152
- DSS circuit here). No-op by default because most offline
153
- components have no resources to acquire.
154
- """
155
-
156
- def stop(self) -> None:
157
- """Release per-run resources. Simulation state is preserved.
158
-
159
- Called after the simulation loop in LIFO order. Override for
160
- backends that acquired resources in [`start`][..start]. No-op
161
- by default.
162
- """
163
-
164
-
165
- class LLMBatchSizeControlledDatacenter(DatacenterBackend[DCStateT]):
166
- """Datacenter that serves LLM workloads and supports batch-size control.
167
-
168
- Marker layer between [`DatacenterBackend`][..DatacenterBackend] and
169
- concrete implementations. Controllers that issue
170
- [`SetBatchSize`][openg2g.datacenter.command.SetBatchSize] commands or read
171
- `active_replicas_by_model` / `observed_itl_s_by_model`
172
- from state should bind their generic to this class.
173
- """
174
-
175
- @property
176
- def phase_share_by_model(self) -> dict[str, np.ndarray]:
177
- """Per-model phase share vectors `[frac_A, frac_B, frac_C]`.
178
-
179
- Returns an empty dict by default. Consumers treat missing keys
180
- as uniform `[1/3, 1/3, 1/3]`. Override in subclasses that know
181
- actual server-to-phase placement.
182
- """
183
- return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/command.py DELETED
@@ -1,31 +0,0 @@
1
- """Command types targeting datacenter backends."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass, field
6
-
7
-
8
- class DatacenterCommand:
9
- """Base for commands targeting the datacenter backend.
10
-
11
- Subclass this for each concrete datacenter command kind.
12
- The coordinator routes commands to backends based on this type hierarchy.
13
- """
14
-
15
- def __init__(self) -> None:
16
- if type(self) is DatacenterCommand:
17
- raise TypeError("DatacenterCommand cannot be instantiated directly; subclass it.")
18
-
19
-
20
- @dataclass(frozen=True)
21
- class SetBatchSize(DatacenterCommand):
22
- """Set batch sizes for one or more models.
23
-
24
- Attributes:
25
- batch_size_by_model: Mapping of model label to target batch size.
26
- ramp_up_rate_by_model: Per-model requests/second ramp-up rate.
27
- Models not present get immediate changes (rate 0).
28
- """
29
-
30
- batch_size_by_model: dict[str, int]
31
- ramp_up_rate_by_model: dict[str, float] = field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/config.py DELETED
@@ -1,342 +0,0 @@
1
- """Datacenter facility and workload configuration."""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Iterator
6
- from dataclasses import dataclass
7
-
8
- import numpy as np
9
- from pydantic import BaseModel, ConfigDict, model_validator
10
-
11
- from openg2g.datacenter.workloads.training import TrainingTrace
12
-
13
-
14
- class InferenceModelSpec(BaseModel):
15
- """Specification for one LLM model served in the datacenter.
16
-
17
- Attributes:
18
- model_label: Human-readable model identifier (e.g. `"Llama-3.1-70B"`).
19
- model_id: HuggingFace model ID (e.g. `"meta-llama/Llama-3.1-70B-Instruct"`).
20
- Used for benchmark data lookups and online API model fields.
21
- num_replicas: Total number of replicas of this model across the datacenter.
22
- gpus_per_replica: GPUs allocated to each replica (determines model
23
- parallelism and per-replica power draw).
24
- initial_batch_size: Initial batch size for this model.
25
- itl_deadline_s: Per-model inter-token latency deadline for the OFO
26
- latency dual (seconds).
27
- feasible_batch_sizes: Allowed batch sizes. Used by the OFO
28
- controller for discretizing continuous batch-size updates
29
- and by the online datacenter for load-generator sizing.
30
- Defaults to `(initial_batch_size,)`.
31
- """
32
-
33
- model_config = ConfigDict(frozen=True)
34
-
35
- model_label: str
36
- model_id: str = ""
37
- num_replicas: int
38
- gpus_per_replica: int
39
- initial_batch_size: int
40
- itl_deadline_s: float
41
- feasible_batch_sizes: tuple[int, ...] = ()
42
-
43
- @model_validator(mode="after")
44
- def _validate(self) -> InferenceModelSpec:
45
- if not self.feasible_batch_sizes:
46
- object.__setattr__(self, "feasible_batch_sizes", (self.initial_batch_size,))
47
- elif self.initial_batch_size not in self.feasible_batch_sizes:
48
- raise ValueError(
49
- f"initial_batch_size ({self.initial_batch_size}) must be in "
50
- f"feasible_batch_sizes ({self.feasible_batch_sizes})."
51
- )
52
- if self.num_replicas < 0:
53
- raise ValueError(f"num_replicas must be >= 0, got {self.num_replicas}.")
54
- if self.gpus_per_replica < 1:
55
- raise ValueError(f"gpus_per_replica must be >= 1, got {self.gpus_per_replica}.")
56
- if self.initial_batch_size <= 0:
57
- raise ValueError(f"initial_batch_size must be > 0, got {self.initial_batch_size}.")
58
- if self.itl_deadline_s <= 0:
59
- raise ValueError(f"itl_deadline_s must be > 0, got {self.itl_deadline_s}.")
60
- return self
61
-
62
-
63
- class TrainingRun:
64
- """Training workload parameters.
65
-
66
- The trace is eagerly rescaled so its peak matches `target_peak_W_per_gpu`.
67
- Use `eval_power` to evaluate total training power at a given simulation time.
68
-
69
- Combine with [`at`][.at] and `|` to build a [`TrainingSchedule`][..TrainingSchedule]:
70
-
71
- ```python
72
- schedule = (
73
- TrainingRun(n_gpus=2400, trace=trace_a).at(t_start=1000, t_end=2000)
74
- | TrainingRun(n_gpus=1200, trace=trace_b).at(t_start=2500, t_end=3500)
75
- )
76
- ```
77
-
78
- Attributes:
79
- n_gpus: Number of GPUs running the training workload.
80
- trace: Single-GPU [`TrainingTrace`][openg2g.datacenter.workloads.training.TrainingTrace].
81
- target_peak_W_per_gpu: The trace is rescaled so its peak equals this value.
82
- """
83
-
84
- __slots__ = ("_period", "_rescaled_power", "_trace_time", "n_gpus", "target_peak_W_per_gpu", "trace")
85
-
86
- def __init__(self, *, n_gpus: int, trace: TrainingTrace, target_peak_W_per_gpu: float = 400.0) -> None:
87
- if n_gpus <= 0:
88
- raise ValueError(f"TrainingRun n_gpus must be > 0, got {n_gpus}.")
89
- self.n_gpus = n_gpus
90
- self.trace = trace
91
- self.target_peak_W_per_gpu = target_peak_W_per_gpu
92
-
93
- t = np.asarray(trace.t_s, float)
94
- p = np.asarray(trace.power_w, float)
95
- t = t - t[0]
96
- period = float(t[-1] - t[0])
97
- if period <= 0:
98
- raise ValueError("Training trace time span must be positive.")
99
- peak = float(np.max(p))
100
- if peak <= 0:
101
- raise ValueError("Training trace has non-positive peak; cannot scale.")
102
- self._rescaled_power = p * (target_peak_W_per_gpu / peak)
103
- self._trace_time = t
104
- self._period = period
105
-
106
- def eval_power(self, t: float, t_start: float, t_end: float) -> float:
107
- """Evaluate total training power at simulation time `t`.
108
-
109
- Returns zero if `t` is outside `[t_start, t_end]`.
110
-
111
- Args:
112
- t: Global simulation time (seconds).
113
- t_start: Time when training becomes active (seconds).
114
- t_end: Time when training stops (seconds).
115
-
116
- Returns:
117
- Total training power (W) across all `n_gpus` GPUs.
118
- """
119
- if t < t_start or t > t_end:
120
- return 0.0
121
- t_local = t - t_start
122
- t_mod = t_local % self._period
123
- p_1gpu = float(np.interp(t_mod, self._trace_time, self._rescaled_power))
124
- return p_1gpu * self.n_gpus
125
-
126
- def at(self, t_start: float, t_end: float) -> TrainingSchedule:
127
- """Schedule this training run over `[t_start, t_end]`.
128
-
129
- Args:
130
- t_start: Global simulation time when training becomes active (seconds).
131
- t_end: Global simulation time when training stops (seconds).
132
-
133
- Returns:
134
- A single-entry [`TrainingSchedule`][...TrainingSchedule].
135
- """
136
- if t_end < t_start:
137
- raise ValueError(f"t_end ({t_end}) must be >= t_start ({t_start}).")
138
- return TrainingSchedule(((self, float(t_start), float(t_end)),))
139
-
140
-
141
- class TrainingSchedule:
142
- """Ordered collection of [`TrainingRun`][..TrainingRun] objects scheduled
143
- over time windows.
144
-
145
- Each entry is a `(TrainingRun, t_start, t_end)` tuple. Entries are
146
- sorted by `t_start`.
147
-
148
- Built with [`TrainingRun.at`][..TrainingRun.at] and `|`.
149
-
150
- Example:
151
-
152
- ```python
153
- schedule = (
154
- TrainingRun(n_gpus=2400, trace=trace_a).at(t_start=1000, t_end=2000)
155
- | TrainingRun(n_gpus=1200, trace=trace_b).at(t_start=2500, t_end=3500)
156
- )
157
- ```
158
- """
159
-
160
- __slots__ = ("_entries",)
161
-
162
- def __init__(self, entries: tuple[tuple[TrainingRun, float, float], ...] = ()) -> None:
163
- self._entries = tuple(sorted(entries, key=lambda e: e[1]))
164
-
165
- def __or__(self, other: TrainingSchedule) -> TrainingSchedule:
166
- return TrainingSchedule((*self._entries, *other._entries))
167
-
168
- def __iter__(self) -> Iterator[tuple[TrainingRun, float, float]]:
169
- return iter(self._entries)
170
-
171
- def __len__(self) -> int:
172
- return len(self._entries)
173
-
174
- def __bool__(self) -> bool:
175
- return bool(self._entries)
176
-
177
- def __repr__(self) -> str:
178
- parts = [f"TrainingRun(n_gpus={r.n_gpus}).at(t_start={s}, t_end={e})" for r, s, e in self._entries]
179
- return " | ".join(parts)
180
-
181
-
182
- @dataclass(frozen=True)
183
- class InferenceRamp:
184
- """Inference server ramp parameters.
185
-
186
- Transitions the active inference server fraction to `target`. Combine with
187
- [`at`][.at] and `|` to build an [`InferenceRampSchedule`][..InferenceRampSchedule]:
188
-
189
- ```python
190
- ramps = (
191
- InferenceRamp(target=0.2).at(t_start=2500, t_end=3000)
192
- | InferenceRamp(target=1.0).at(t_start=3200, t_end=3400)
193
- )
194
- ```
195
-
196
- Attributes:
197
- target: Target active-server fraction after the ramp (0.0--1.0).
198
- """
199
-
200
- target: float
201
-
202
- def __post_init__(self) -> None:
203
- if not (0.0 <= self.target <= 1.0):
204
- raise ValueError(f"InferenceRamp target must be in [0.0, 1.0], got {self.target}.")
205
-
206
- def at(self, t_start: float, t_end: float) -> InferenceRampSchedule:
207
- """Schedule this ramp over `[t_start, t_end]`.
208
-
209
- Args:
210
- t_start: Global simulation time when the ramp begins (seconds).
211
- t_end: Global simulation time when the ramp ends (seconds).
212
-
213
- Returns:
214
- A single-entry [`InferenceRampSchedule`][...InferenceRampSchedule].
215
- """
216
- if t_end < t_start:
217
- raise ValueError(f"t_end ({t_end}) must be >= t_start ({t_start}).")
218
- return InferenceRampSchedule(((self, float(t_start), float(t_end)),))
219
-
220
-
221
- class InferenceRampSchedule:
222
- """Ordered collection of [`InferenceRamp`][..InferenceRamp] events.
223
-
224
- Each entry is an `(InferenceRamp, t_start, t_end)` tuple. Entries are
225
- sorted by `t_start`.
226
-
227
- Built with [`InferenceRamp.at`][..InferenceRamp.at] and `|`.
228
-
229
- Semantics: before the first ramp, fraction = 1.0. During each
230
- `[t_start, t_end]` window, the fraction linearly interpolates from
231
- the previous level to `target`. Between ramps, the fraction holds
232
- at the last target.
233
-
234
- An empty schedule means all servers are active (fraction = 1.0)
235
- at all times.
236
-
237
- Example:
238
-
239
- ```python
240
- ramps = (
241
- InferenceRamp(target=0.2).at(t_start=2500, t_end=3000)
242
- | InferenceRamp(target=1.0).at(t_start=3200, t_end=3400)
243
- )
244
- ```
245
- """
246
-
247
- __slots__ = ("_entries",)
248
-
249
- def __init__(self, entries: tuple[tuple[InferenceRamp, float, float], ...] = ()) -> None:
250
- self._entries = tuple(sorted(entries, key=lambda e: e[1]))
251
-
252
- def __or__(self, other: InferenceRampSchedule) -> InferenceRampSchedule:
253
- return InferenceRampSchedule((*self._entries, *other._entries))
254
-
255
- def __iter__(self) -> Iterator[tuple[InferenceRamp, float, float]]:
256
- return iter(self._entries)
257
-
258
- def __len__(self) -> int:
259
- return len(self._entries)
260
-
261
- def __bool__(self) -> bool:
262
- return bool(self._entries)
263
-
264
- def __repr__(self) -> str:
265
- parts = [f"InferenceRamp(target={r.target}).at(t_start={s}, t_end={e})" for r, s, e in self._entries]
266
- return " | ".join(parts)
267
-
268
- def fraction_at(self, t: float | np.ndarray) -> float | np.ndarray:
269
- """Evaluate the active inference server fraction at time(s) *t*.
270
-
271
- Piecewise-linear interpolation between ramp events.
272
- Before the first ramp, fraction = 1.0.
273
-
274
- Args:
275
- t: Scalar or array of global simulation times (seconds).
276
-
277
- Returns:
278
- Active-server fraction(s), same shape as *t*.
279
- """
280
- if isinstance(t, np.ndarray):
281
- return self._fraction_array(t)
282
- return float(self._fraction_scalar(float(t)))
283
-
284
- def _fraction_scalar(self, t: float) -> float:
285
- level = 1.0
286
- for ramp, t_start, t_end in self._entries:
287
- if t < t_start:
288
- return level
289
- if t <= t_end:
290
- if t_end == t_start:
291
- return ramp.target
292
- alpha = (t - t_start) / (t_end - t_start)
293
- return level + (ramp.target - level) * alpha
294
- level = ramp.target
295
- return level
296
-
297
- def _fraction_array(self, t: np.ndarray) -> np.ndarray:
298
- vfunc = np.vectorize(self._fraction_scalar, otypes=[float])
299
- return vfunc(t)
300
-
301
-
302
- class DatacenterConfig(BaseModel):
303
- """Physical datacenter facility configuration.
304
-
305
- Attributes:
306
- gpus_per_server: Number of GPUs per physical server rack.
307
- base_kw_per_phase: Constant base load per phase (kW).
308
- power_factor: Power factor of the datacenter loads (lagging).
309
- """
310
-
311
- model_config = ConfigDict(frozen=True)
312
-
313
- gpus_per_server: int = 8
314
- base_kw_per_phase: float = 0.0
315
- power_factor: float = 0.95
316
-
317
- @model_validator(mode="after")
318
- def _validate(self) -> DatacenterConfig:
319
- if self.gpus_per_server < 1:
320
- raise ValueError(f"gpus_per_server must be >= 1, got {self.gpus_per_server}.")
321
- if not (0.0 < self.power_factor <= 1.0):
322
- raise ValueError(f"power_factor must be in (0, 1], got {self.power_factor}.")
323
- return self
324
-
325
-
326
- class PowerAugmentationConfig(BaseModel):
327
- """Power augmentation settings for virtual server scaling.
328
-
329
- Controls per-server amplitude jitter and additive noise applied during
330
- power augmentation.
331
-
332
- Attributes:
333
- amplitude_scale_range: `(low, high)` range for per-server amplitude
334
- scaling. Each virtual server draws a uniform multiplier from this range.
335
- noise_fraction: Gaussian noise standard deviation as a fraction of
336
- per-server power.
337
- """
338
-
339
- model_config = ConfigDict(frozen=True)
340
-
341
- amplitude_scale_range: tuple[float, float] = (1.0, 1.0)
342
- noise_fraction: float = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/layout.py DELETED
@@ -1,126 +0,0 @@
1
- """Server layout and activation policy primitives.
2
-
3
- Provides the topology and activation-policy building blocks used by
4
- datacenter backends. Power augmentation (scaling per-GPU power to
5
- three-phase datacenter power) lives in
6
- `openg2g.datacenter.workloads.inference`.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- from abc import ABC, abstractmethod
12
- from dataclasses import dataclass
13
-
14
- import numpy as np
15
-
16
- from openg2g.datacenter.config import InferenceRampSchedule
17
-
18
-
19
- class ActivationPolicy(ABC):
20
- """Per-model activation policy that answers "which servers are active?"
21
-
22
- Subclass to implement custom activation logic. The datacenter creates
23
- one policy per model and passes it to
24
- [`InferencePowerAugmenter`][openg2g.datacenter.workloads.inference.InferencePowerAugmenter].
25
- """
26
-
27
- @abstractmethod
28
- def active_mask(self, t: float) -> np.ndarray:
29
- """Boolean mask of active servers at time *t*.
30
-
31
- Returns:
32
- Array of shape `(num_servers,)` with `True` for active servers.
33
- """
34
-
35
- def active_indices(self, t: float) -> np.ndarray:
36
- """Indices of active servers at time *t*.
37
-
38
- The default implementation returns indices in ascending order
39
- via `np.where(`[`active_mask`][..active_mask]`(t))`. Subclasses
40
- may override to return
41
- indices in a specific order (e.g., priority order) to control
42
- floating-point summation order in the datacenter.
43
-
44
- Returns:
45
- 1-D int array of active server indices.
46
- """
47
- return np.where(self.active_mask(t))[0]
48
-
49
-
50
- class RampActivationPolicy(ActivationPolicy):
51
- """Activate servers by fixed random priority, following an
52
- [`InferenceRampSchedule`][openg2g.datacenter.config.InferenceRampSchedule].
53
-
54
- At time *t*, the top-*k* servers (by random priority) are active,
55
- where `k = round(schedule.fraction_at(t) * num_servers)`.
56
-
57
- This is the default policy used by
58
- [`OfflineDatacenter`][openg2g.datacenter.offline.OfflineDatacenter].
59
-
60
- Args:
61
- schedule: Temporal ramp schedule mapping time to active-server fraction.
62
- num_servers: Number of physical servers for this model.
63
- rng: RNG for randomizing priority ordering. Consumed once at
64
- construction time.
65
- """
66
-
67
- __slots__ = ("_n", "_priority", "_schedule")
68
-
69
- def __init__(
70
- self,
71
- schedule: InferenceRampSchedule,
72
- num_servers: int,
73
- rng: np.random.Generator,
74
- ) -> None:
75
- self._schedule = schedule
76
- self._n = num_servers
77
- priority = np.arange(num_servers, dtype=int)
78
- rng.shuffle(priority)
79
- self._priority = priority
80
-
81
- def active_mask(self, t: float) -> np.ndarray:
82
- frac = self._schedule.fraction_at(t)
83
- k = max(0, min(self._n, int(round(float(frac) * self._n))))
84
- mask = np.zeros(self._n, dtype=bool)
85
- mask[self._priority[:k]] = True
86
- return mask
87
-
88
- def active_indices(self, t: float) -> np.ndarray:
89
- """Return active server indices in priority order."""
90
- frac = self._schedule.fraction_at(t)
91
- k = max(0, min(self._n, int(round(float(frac) * self._n))))
92
- return self._priority[:k].copy()
93
-
94
-
95
- @dataclass
96
- class ServerLayout:
97
- """Per-model server layout describing how GPUs are organized.
98
-
99
- This describes the physical topology only. Activation policies (which
100
- servers are on/off at a given time) are managed separately by the
101
- datacenter and passed to
102
- [`InferencePowerAugmenter`][openg2g.datacenter.workloads.inference.InferencePowerAugmenter]
103
- alongside layouts.
104
-
105
- Attributes:
106
- num_servers: Number of physical servers for this model.
107
- total_gpus: Total GPU count across all servers.
108
- gpus_per_replica: GPUs per model replica.
109
- gpus_per_server_list: GPU count per server (last may be partial).
110
- phase_list: Phase assignment per server (0=A, 1=B, 2=C).
111
- stagger_offsets: Per-server offsets for desynchronization. In offline
112
- mode these are integer indices into a power template; in online
113
- mode they can be float time offsets into a rolling buffer.
114
- amplitude_scales: Per-server power multiplier for inter-server variation.
115
- noise_fraction: Gaussian noise standard deviation as a fraction of
116
- per-server power.
117
- """
118
-
119
- num_servers: int
120
- total_gpus: int
121
- gpus_per_replica: int
122
- gpus_per_server_list: np.ndarray
123
- phase_list: np.ndarray
124
- stagger_offsets: np.ndarray
125
- amplitude_scales: np.ndarray
126
- noise_fraction: float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/offline.py DELETED
@@ -1,320 +0,0 @@
1
- """Offline (trace-based) datacenter backend."""
2
-
3
- from __future__ import annotations
4
-
5
- import functools
6
- import logging
7
- import math
8
- from dataclasses import dataclass, field
9
- from fractions import Fraction
10
-
11
- import numpy as np
12
-
13
- from openg2g.clock import SimulationClock
14
- from openg2g.common import ThreePhase
15
- from openg2g.datacenter.base import LLMBatchSizeControlledDatacenter, LLMDatacenterState
16
- from openg2g.datacenter.command import DatacenterCommand, SetBatchSize
17
- from openg2g.datacenter.config import (
18
- DatacenterConfig,
19
- InferenceRampSchedule,
20
- PowerAugmentationConfig,
21
- TrainingSchedule,
22
- )
23
- from openg2g.datacenter.layout import (
24
- ActivationPolicy,
25
- RampActivationPolicy,
26
- ServerLayout,
27
- )
28
- from openg2g.datacenter.workloads.inference import InferenceData, InferencePowerAugmenter
29
- from openg2g.events import EventEmitter
30
- from openg2g.utils import split_integer_evenly
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- @dataclass(frozen=True)
36
- class OfflineDatacenterState(LLMDatacenterState):
37
- """Extended state from the offline (trace-based) backend.
38
-
39
- Adds per-model power breakdown to
40
- [`LLMDatacenterState`][openg2g.datacenter.base.LLMDatacenterState].
41
- """
42
-
43
- power_by_model_w: dict[str, float] = field(default_factory=dict)
44
-
45
-
46
- @dataclass
47
- class OfflineWorkload:
48
- """Complete offline simulation workload.
49
-
50
- Bundles inference data with optional training overlays and inference
51
- server ramp events.
52
-
53
- Attributes:
54
- inference_data: LLM inference workload with offline simulation
55
- data (model specs, power templates, ITL fits).
56
- inference_ramps: Inference server ramp schedule. `None` keeps all
57
- servers active.
58
- training: Training workload schedule. `None` disables training
59
- overlay.
60
- """
61
-
62
- inference_data: InferenceData
63
- inference_ramps: InferenceRampSchedule = field(default_factory=InferenceRampSchedule)
64
- training: TrainingSchedule = field(default_factory=TrainingSchedule)
65
-
66
-
67
- class OfflineDatacenter(LLMBatchSizeControlledDatacenter[OfflineDatacenterState]):
68
- """Trace-based datacenter simulation with step-by-step interface.
69
-
70
- Each `step` call computes one timestep of power output by indexing
71
- into pre-built per-GPU templates, applying per-server amplitude
72
- scaling and noise, and summing across active servers per phase.
73
-
74
- Batch size changes via `apply_control` take effect on the next
75
- `step` call.
76
-
77
- If `workload.inference_ramps` is set, a
78
- [`RampActivationPolicy`][openg2g.datacenter.layout.RampActivationPolicy]
79
- is created per model.
80
-
81
- Args:
82
- datacenter: Facility configuration (GPUs per server, base load).
83
- workload: Offline workload configuration bundling inference data,
84
- training overlays, and server ramp events.
85
- dt_s: Simulation timestep (seconds).
86
- seed: Random seed for layout generation, noise, and latency
87
- sampling. Sub-seeds are derived deterministically.
88
- power_augmentation: Per-server amplitude scaling and noise
89
- settings.
90
- """
91
-
92
- def __init__(
93
- self,
94
- datacenter: DatacenterConfig,
95
- workload: OfflineWorkload,
96
- *,
97
- dt_s: Fraction,
98
- seed: int = 0,
99
- power_augmentation: PowerAugmentationConfig | None = None,
100
- ) -> None:
101
- super().__init__()
102
- if power_augmentation is None:
103
- power_augmentation = PowerAugmentationConfig()
104
-
105
- self._datacenter = datacenter
106
- self._workload = workload
107
- self._power_augmentation = power_augmentation
108
- self._dt_s = dt_s
109
- self._seed = int(seed)
110
- self._models = list(workload.inference_data.models)
111
- self._base_W_per_phase = float(datacenter.base_kw_per_phase) * 1e3
112
-
113
- self._layout_rng = np.random.default_rng(self._seed)
114
- self._batch_by_model: dict[str, int] = {ms.model_label: ms.initial_batch_size for ms in self._models}
115
-
116
- self._layouts: dict[str, ServerLayout] = {}
117
- self._policies: dict[str, ActivationPolicy] = {}
118
- self._build_all_layouts()
119
- self._inference_augmenter = InferencePowerAugmenter(
120
- layouts=self._layouts,
121
- policies=self._policies,
122
- seed=self._seed + 12345,
123
- )
124
-
125
- self._global_step: int = 0
126
- self._latency_rng = np.random.default_rng(self._seed + 54321)
127
-
128
- logger.info(
129
- "OfflineDatacenter: %d models, dt=%s s, seed=%d",
130
- len(self._models),
131
- dt_s,
132
- seed,
133
- )
134
- for ms in self._models:
135
- logger.info(
136
- " %s: %d replicas, %d GPUs/replica, batch=%d",
137
- ms.model_label,
138
- ms.num_replicas,
139
- ms.gpus_per_replica,
140
- ms.initial_batch_size,
141
- )
142
-
143
- @property
144
- def dt_s(self) -> Fraction:
145
- return self._dt_s
146
-
147
- def step(self, clock: SimulationClock, events: EventEmitter) -> OfflineDatacenterState:
148
- t_now = clock.time_s
149
- template_store = self._workload.inference_data.power_templates
150
-
151
- # Build per-GPU power dict by indexing into templates with layout offsets.
152
- per_gpu_by_model: dict[str, np.ndarray] = {}
153
- for ms in self._models:
154
- label = ms.model_label
155
- if ms.num_replicas <= 0:
156
- continue
157
- batch = int(self._batch_by_model[label])
158
-
159
- layout = self._layouts[label]
160
- template = template_store.template(label, batch)
161
- indices = (self._global_step + layout.stagger_offsets) % len(template)
162
- per_gpu_by_model[label] = template[indices]
163
-
164
- inference_aug = self._inference_augmenter.augment(per_gpu_by_model, t_now)
165
-
166
- power_by_model = dict(inference_aug.power_by_model_w)
167
- active_replicas_by_model = dict(inference_aug.active_replicas_by_model)
168
- for ms in self._models:
169
- power_by_model.setdefault(ms.model_label, 0.0)
170
- active_replicas_by_model.setdefault(ms.model_label, 0)
171
-
172
- # This is where we accumulate power across workloads.
173
- phase_power = np.array(
174
- [
175
- self._base_W_per_phase + inference_aug.power_w.a,
176
- self._base_W_per_phase + inference_aug.power_w.b,
177
- self._base_W_per_phase + inference_aug.power_w.c,
178
- ]
179
- )
180
-
181
- # Training overlay
182
- for run, t_start, t_end in self._workload.training:
183
- training_power_w = run.eval_power(float(t_now), t_start, t_end)
184
- phase_power += training_power_w / 3.0
185
-
186
- # ITL sampling
187
- itl_fits = self._workload.inference_data.itl_fits
188
- observed_itl_s_by_model: dict[str, float] = {}
189
- for ms in self._models:
190
- label = ms.model_label
191
- n_rep = active_replicas_by_model.get(label, 0)
192
- if itl_fits is None or n_rep <= 0:
193
- observed_itl_s_by_model[label] = float("nan")
194
- continue
195
- batch = int(self._batch_by_model[label])
196
- observed_itl_s_by_model[label] = itl_fits.sample_avg(
197
- model_label=label,
198
- batch_size=batch,
199
- n_replicas=n_rep,
200
- rng=self._latency_rng,
201
- )
202
-
203
- state = OfflineDatacenterState(
204
- time_s=float(t_now),
205
- power_w=ThreePhase(
206
- a=float(phase_power[0]),
207
- b=float(phase_power[1]),
208
- c=float(phase_power[2]),
209
- ),
210
- power_by_model_w=power_by_model,
211
- active_replicas_by_model=active_replicas_by_model,
212
- batch_size_by_model=dict(self._batch_by_model),
213
- observed_itl_s_by_model=observed_itl_s_by_model,
214
- )
215
- self._global_step += 1
216
- return state
217
-
218
- @functools.singledispatchmethod
219
- def apply_control(self, command: DatacenterCommand, events: EventEmitter) -> None:
220
- """Apply a control command. Dispatches on command type."""
221
- raise TypeError(f"OfflineDatacenter does not support {type(command).__name__}")
222
-
223
- @apply_control.register
224
- def apply_control_set_batch_size(self, command: SetBatchSize, events: EventEmitter) -> None:
225
- """Record new batch sizes. Changes take effect on the next step."""
226
- if command.ramp_up_rate_by_model:
227
- raise ValueError(
228
- f"OfflineDatacenter does not support ramp_up_rate_by_model (got {command.ramp_up_rate_by_model}). "
229
- f"Batch size changes are always immediate in trace-based simulation."
230
- )
231
- for label, b in command.batch_size_by_model.items():
232
- b_int = int(b)
233
- if b_int <= 0:
234
- raise ValueError(f"Batch size must be positive for model {label!r}, got {b_int}.")
235
- old = self._batch_by_model.get(str(label))
236
- self._batch_by_model[str(label)] = b_int
237
- if old != b_int:
238
- logger.info("Batch size %s: %s -> %d", label, old, b_int)
239
- events.emit(
240
- "datacenter.batch_size.updated",
241
- {"batch_size_by_model": dict(self._batch_by_model)},
242
- )
243
-
244
- def reset(self) -> None:
245
- self._global_step = 0
246
- self._batch_by_model = {ms.model_label: ms.initial_batch_size for ms in self._models}
247
- self._layout_rng = np.random.default_rng(self._seed)
248
- self._layouts = {}
249
- self._policies = {}
250
- self._build_all_layouts()
251
- self._inference_augmenter = InferencePowerAugmenter(
252
- layouts=self._layouts,
253
- policies=self._policies,
254
- seed=self._seed + 12345,
255
- )
256
- self._latency_rng = np.random.default_rng(self._seed + 54321)
257
-
258
- def _build_all_layouts(self) -> None:
259
- """Build layouts and activation policies for all models."""
260
- schedule = self._workload.inference_ramps
261
- rng = self._layout_rng
262
- gpus_per_server = self._datacenter.gpus_per_server
263
- amp_lo, amp_hi = self._power_augmentation.amplitude_scale_range
264
- noise_fraction = self._power_augmentation.noise_fraction
265
- template_store = self._workload.inference_data.power_templates
266
-
267
- for ms in self._models:
268
- if ms.num_replicas > 0:
269
- any_batch = template_store.batch_sizes(ms.model_label)[0]
270
- tpl_len = len(template_store.template(ms.model_label, any_batch))
271
-
272
- num_servers = math.ceil(ms.num_replicas * ms.gpus_per_replica / gpus_per_server)
273
-
274
- # Phase shuffle
275
- sA, sB, sC = split_integer_evenly(num_servers, 3)
276
- phase_list = np.asarray(([0] * sA) + ([1] * sB) + ([2] * sC), dtype=int)
277
- rng.shuffle(phase_list)
278
-
279
- # Policy dictates which servers are active at a given time.
280
- self._policies[ms.model_label] = RampActivationPolicy(schedule, num_servers, rng)
281
-
282
- # This offset determines for each server, how much to stagger its power template indexing.
283
- stagger_offsets = rng.integers(low=0, high=max(tpl_len, 1), size=num_servers)
284
-
285
- # Amplitude scales
286
- amplitude_scales = rng.uniform(amp_lo, amp_hi, size=num_servers)
287
-
288
- total_gpus = ms.num_replicas * ms.gpus_per_replica
289
- gpus_per_server_list = np.full(num_servers, gpus_per_server, dtype=int)
290
- tail = total_gpus - (num_servers - 1) * gpus_per_server
291
- gpus_per_server_list[-1] = int(tail) if tail > 0 else gpus_per_server
292
-
293
- self._layouts[ms.model_label] = ServerLayout(
294
- num_servers=num_servers,
295
- total_gpus=total_gpus,
296
- gpus_per_replica=ms.gpus_per_replica,
297
- gpus_per_server_list=gpus_per_server_list,
298
- phase_list=phase_list,
299
- stagger_offsets=stagger_offsets,
300
- amplitude_scales=amplitude_scales,
301
- noise_fraction=noise_fraction,
302
- )
303
-
304
- @property
305
- def phase_share_by_model(self) -> dict[str, np.ndarray]:
306
- """Per-model phase share vectors derived from server placement.
307
-
308
- Returns:
309
- Mapping of model label to a 3-element array `[frac_A, frac_B, frac_C]`
310
- representing the fraction of servers on each phase.
311
- """
312
- shares: dict[str, np.ndarray] = {}
313
- for label, layout in self._layouts.items():
314
- counts = np.bincount(layout.phase_list, minlength=3).astype(float)
315
- total = counts.sum()
316
- if total > 0:
317
- shares[label] = counts / total
318
- else:
319
- shares[label] = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
320
- return shares
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/online.py DELETED
@@ -1,1196 +0,0 @@
1
- """Online (live GPU) datacenter backend with power augmentation.
2
-
3
- Connects to real vLLM inference servers for load generation and ITL
4
- measurement, and to zeusd instances for live GPU power monitoring.
5
- Power readings from a small number of real GPUs are augmented to
6
- datacenter scale using the shared
7
- [`InferencePowerAugmenter`][openg2g.datacenter.workloads.inference.InferencePowerAugmenter]
8
- pipeline.
9
-
10
- Requires `pip install zeus aiohttp`.
11
- """
12
-
13
- from __future__ import annotations
14
-
15
- import asyncio
16
- import collections
17
- import contextlib
18
- import functools
19
- import json
20
- import logging
21
- import math
22
- import re
23
- import threading
24
- import time
25
- import urllib.request
26
- from collections.abc import Sequence
27
- from dataclasses import dataclass, field
28
- from fractions import Fraction
29
- from pathlib import Path
30
- from typing import Any
31
-
32
- import aiohttp
33
- import numpy as np
34
- from pydantic import BaseModel, ConfigDict
35
- from zeus.monitor.power_streaming import PowerStreamingClient
36
- from zeus.utils.zeusd import ZeusdConfig
37
-
38
- from openg2g.clock import SimulationClock
39
- from openg2g.common import ThreePhase
40
- from openg2g.datacenter.base import LLMBatchSizeControlledDatacenter, LLMDatacenterState
41
- from openg2g.datacenter.command import DatacenterCommand, SetBatchSize
42
- from openg2g.datacenter.config import (
43
- DatacenterConfig,
44
- InferenceModelSpec,
45
- InferenceRampSchedule,
46
- PowerAugmentationConfig,
47
- )
48
- from openg2g.datacenter.layout import (
49
- ActivationPolicy,
50
- RampActivationPolicy,
51
- ServerLayout,
52
- )
53
- from openg2g.datacenter.workloads.inference import (
54
- InferencePowerAugmenter,
55
- RequestStore,
56
- )
57
- from openg2g.events import EventEmitter
58
- from openg2g.utils import split_integer_evenly
59
-
60
- logger = logging.getLogger(__name__)
61
-
62
-
63
- @dataclass(frozen=True)
64
- class OnlineDatacenterState(LLMDatacenterState):
65
- """Extended state from the online (live GPU) backend.
66
-
67
- The base `power_w`
68
- field carries the augmented three-phase power (what the grid sees).
69
- This subclass adds the measured (pre-augmentation) breakdown for
70
- post-hoc analysis.
71
-
72
- Attributes:
73
- measured_power_w: Total measured three-phase power from real GPUs
74
- (before augmentation), plus base load.
75
- measured_power_w_by_model: Per-model total measured power from real
76
- GPUs (watts).
77
- augmented_power_w_by_model: Per-model augmented power (watts). This
78
- is the power fed to the grid for each model after scaling up.
79
- augmentation_factor_by_model: Per-model augmentation multiplier
80
- (virtual replicas / real replicas).
81
- prometheus_metrics_by_model: Per-model Prometheus metrics snapshot.
82
- Keys are model labels, values are dicts with metric names like
83
- `num_requests_running`, `num_requests_waiting`,
84
- `kv_cache_usage_perc`, `num_preemptions_total`.
85
- """
86
-
87
- measured_power_w: ThreePhase = field(default_factory=lambda: ThreePhase(a=0.0, b=0.0, c=0.0))
88
- measured_power_w_by_model: dict[str, float] = field(default_factory=dict)
89
- augmented_power_w_by_model: dict[str, float] = field(default_factory=dict)
90
- augmentation_factor_by_model: dict[str, float] = field(default_factory=dict)
91
- prometheus_metrics_by_model: dict[str, dict[str, float]] = field(default_factory=dict)
92
-
93
-
94
- class GPUEndpointMapping(BaseModel):
95
- """Maps a zeusd endpoint to specific GPUs.
96
-
97
- Attributes:
98
- host: Hostname or IP of the zeusd instance.
99
- port: TCP port of the zeusd instance.
100
- gpu_indices: GPU device indices to monitor on this endpoint.
101
- """
102
-
103
- model_config = ConfigDict(frozen=True)
104
-
105
- host: str
106
- port: int = 4938
107
- gpu_indices: tuple[int, ...] = (0,)
108
-
109
- @property
110
- def endpoint_key(self) -> str:
111
- """Return the `host:port` key used by `PowerStreamingClient`."""
112
- return f"{self.host}:{self.port}"
113
-
114
-
115
- class VLLMDeployment(BaseModel):
116
- """Deployment of one LLM model on a vLLM server.
117
-
118
- !!! Warning
119
- vLLM must be a patched version with the `POST /set_max_num_seqs`
120
- endpoint implemented.
121
-
122
- Pairs a reusable
123
- [`InferenceModelSpec`][openg2g.datacenter.config.InferenceModelSpec]
124
- with physical deployment details. `spec.num_replicas` is the
125
- simulated (augmented) count for grid simulation. The real replica
126
- count is derived from `gpu_endpoints` and `spec.gpus_per_replica`.
127
-
128
- Tracks the current batch size (`max_num_seqs`) and provides
129
- `set_batch_size()` to update it on the vLLM server.
130
-
131
- Attributes:
132
- spec: Model specification (shared with offline datacenter).
133
- vllm_base_url: Base URL of the vLLM server (e.g. `http://node1:8000`).
134
- gpu_endpoints: GPU endpoint mappings for power monitoring.
135
- request_extra_body: Extra fields merged into every request dict
136
- for this model (e.g. `chat_template_kwargs`).
137
- batch_size: Current batch size (`max_num_seqs`). Initialized from
138
- `spec.initial_batch_size` if not set explicitly.
139
- """
140
-
141
- spec: InferenceModelSpec
142
- vllm_base_url: str
143
- gpu_endpoints: tuple[GPUEndpointMapping, ...] = ()
144
- request_extra_body: dict[str, Any] | None = None
145
- batch_size: int = 0
146
-
147
- def model_post_init(self, __context: Any) -> None:
148
- if self.batch_size == 0:
149
- self.batch_size = self.spec.initial_batch_size
150
-
151
- @property
152
- def model_label(self) -> str:
153
- return self.spec.model_label
154
-
155
- @property
156
- def num_real_gpus(self) -> int:
157
- """Total number of real GPUs for this model across all endpoints."""
158
- return sum(len(ep.gpu_indices) for ep in self.gpu_endpoints)
159
-
160
- @property
161
- def num_real_replicas(self) -> int:
162
- """Number of real replicas (real GPUs / GPUs per replica)."""
163
- return self.num_real_gpus // max(self.spec.gpus_per_replica, 1)
164
-
165
- @property
166
- def augmentation_factor(self) -> float:
167
- """Ratio of simulated replicas to real replicas."""
168
- return self.spec.num_replicas / max(self.num_real_replicas, 1)
169
-
170
- def set_batch_size(self, batch_size: int, ramp_up_rate: float = 0.0) -> None:
171
- """Update batch size on the vLLM server and track it locally.
172
-
173
- Sends `POST /set_max_num_seqs` to the vLLM server.
174
-
175
- Args:
176
- batch_size: New batch size (max_num_seqs) to set.
177
- ramp_up_rate: Optional ramp-up rate for gradual increase.
178
- """
179
- old = self.batch_size
180
- url = f"{self.vllm_base_url}/set_max_num_seqs?max_num_seqs={batch_size}"
181
- if ramp_up_rate > 0:
182
- url += f"&ramp_up_rate={ramp_up_rate}"
183
- try:
184
- req = urllib.request.Request(url, method="POST", data=b"")
185
- with urllib.request.urlopen(req, timeout=2.0) as resp:
186
- if resp.status >= 400:
187
- raise RuntimeError(
188
- f"Failed to set batch size {batch_size} on {self.vllm_base_url}: HTTP {resp.status}"
189
- )
190
- except Exception:
191
- logger.error(
192
- "Failed to set batch size %d on %s (keeping old=%d)",
193
- batch_size,
194
- self.vllm_base_url,
195
- old,
196
- exc_info=True,
197
- )
198
- raise
199
- self.batch_size = batch_size
200
- if old != batch_size:
201
- logger.info("Batch size %s: %d -> %d", self.model_label, old, batch_size)
202
-
203
-
204
- class LiveServerConfig(BaseModel):
205
- """Configuration for interacting with live vLLM servers.
206
-
207
- Groups settings related to load generation, ITL measurement, and
208
- Prometheus monitoring. The online counterpart of offline's
209
- trace/template data.
210
-
211
- Attributes:
212
- requests_dir: Directory containing per-model JSONL request files
213
- (e.g. `{model_label}.jsonl`). If `None`, a minimal fallback
214
- request is used for each model.
215
- prometheus_poll_interval_s: How often to poll vLLM /metrics for
216
- request counts and saturation monitoring. Set to 0 to disable.
217
- max_output_tokens: Token limit for generated load requests (used
218
- by the fallback request when no JSONL requests are provided).
219
- itl_window_s: Sliding window for ITL averaging (seconds).
220
- """
221
-
222
- model_config = ConfigDict(frozen=True)
223
-
224
- requests_dir: Path | None = None
225
- prometheus_poll_interval_s: float = 0.5
226
- max_output_tokens: int = 512
227
- itl_window_s: float = 1.0
228
-
229
-
230
- STAGGER_BUFFER_S: float = 10.0
231
- """Seconds of power history for temporal staggering.
232
-
233
- Also used as the stagger range when building
234
- [`ServerLayout`][openg2g.datacenter.layout.ServerLayout]
235
- (float offsets drawn from `[0, STAGGER_BUFFER_S)`).
236
-
237
- Not user-configurable. Patchable for testing via
238
- `openg2g.datacenter.online.STAGGER_BUFFER_S = ...`.
239
- """
240
-
241
-
242
- def _check_vllm_health(base_url: str, timeout_s: float = 10.0) -> None:
243
- """Verify a vLLM server is reachable via GET /health.
244
-
245
- Args:
246
- base_url: Base URL of the vLLM server (e.g. `http://node1:8000`).
247
- timeout_s: HTTP timeout in seconds.
248
-
249
- Raises:
250
- RuntimeError: If the server is not reachable or unhealthy.
251
- """
252
- url = f"{base_url}/health"
253
- try:
254
- req = urllib.request.Request(url)
255
- with urllib.request.urlopen(req, timeout=timeout_s) as resp:
256
- if resp.status != 200:
257
- raise RuntimeError(f"vLLM health check failed: HTTP {resp.status} from {url}")
258
- except Exception as e:
259
- raise RuntimeError(f"vLLM health check failed for {url}: {e}") from e
260
-
261
-
262
- def _check_vllm_model(base_url: str, expected_model: str, timeout_s: float = 10.0) -> None:
263
- """Verify a vLLM server is serving the expected model via GET /v1/models.
264
-
265
- Args:
266
- base_url: Base URL of the vLLM server.
267
- expected_model: Model ID to expect in the response.
268
- timeout_s: HTTP timeout in seconds.
269
-
270
- Raises:
271
- RuntimeError: If the model is not served or the endpoint is unreachable.
272
- """
273
- url = f"{base_url}/v1/models"
274
- try:
275
- req = urllib.request.Request(url)
276
- with urllib.request.urlopen(req, timeout=timeout_s) as resp:
277
- if resp.status != 200:
278
- raise RuntimeError(f"vLLM model check failed: HTTP {resp.status} from {url}")
279
- data = json.loads(resp.read().decode())
280
- served = [m["id"] for m in data.get("data", [])]
281
- if expected_model not in served:
282
- raise RuntimeError(f"vLLM at {base_url} serves {served}, expected '{expected_model}'")
283
- except RuntimeError:
284
- raise
285
- except Exception as e:
286
- raise RuntimeError(f"vLLM model check failed for {url}: {e}") from e
287
-
288
-
289
- def _check_zeusd_health(host: str, port: int = 4938, timeout_s: float = 10.0) -> None:
290
- """Verify a zeusd instance is reachable via GET /discover.
291
-
292
- Args:
293
- host: Hostname of the zeusd instance.
294
- port: TCP port.
295
- timeout_s: HTTP timeout in seconds.
296
-
297
- Raises:
298
- RuntimeError: If the zeusd instance is unreachable.
299
- """
300
- url = f"http://{host}:{port}/discover"
301
- try:
302
- req = urllib.request.Request(url)
303
- with urllib.request.urlopen(req, timeout=timeout_s) as resp:
304
- if resp.status != 200:
305
- raise RuntimeError(f"zeusd health check failed: HTTP {resp.status} from {url}")
306
- except RuntimeError:
307
- raise
308
- except Exception as e:
309
- raise RuntimeError(f"zeusd health check failed for {url}: {e}") from e
310
-
311
-
312
- _GAUGE_RE = re.compile(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\{.*?\}\s+(.+)$|^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+(.+)$")
313
-
314
- _PROMETHEUS_METRICS = (
315
- "vllm:num_requests_running",
316
- "vllm:num_requests_waiting",
317
- "vllm:num_preemptions_total",
318
- "vllm:kv_cache_usage_perc",
319
- )
320
-
321
-
322
- def _parse_prometheus_text(text: str) -> dict[str, float]:
323
- """Parse Prometheus text-format metrics and extract vLLM gauges.
324
-
325
- Returns a dict with metric names (without `vllm:` prefix) mapped to
326
- their summed values.
327
- """
328
- raw: dict[str, float] = {}
329
- for line in text.splitlines():
330
- line = line.strip()
331
- if not line or line.startswith("#"):
332
- continue
333
- m = _GAUGE_RE.match(line)
334
- if m:
335
- name = m.group(1) or m.group(3)
336
- val_str = m.group(2) or m.group(4)
337
- if name in _PROMETHEUS_METRICS:
338
- with contextlib.suppress(ValueError):
339
- raw[name] = raw.get(name, 0.0) + float(val_str)
340
-
341
- result: dict[str, float] = {}
342
- for metric in _PROMETHEUS_METRICS:
343
- if metric in raw:
344
- short = metric.removeprefix("vllm:")
345
- result[short] = raw[metric]
346
- return result
347
-
348
-
349
- class _PrometheusPoller:
350
- """Polls vLLM /metrics endpoints for Prometheus gauges.
351
-
352
- Runs as an async task inside `_LoadGenerator`'s event loop.
353
- Provides thread-safe access to the latest snapshot per model.
354
- """
355
-
356
- def __init__(
357
- self,
358
- deployments: Sequence[VLLMDeployment],
359
- poll_interval_s: float = 0.5,
360
- ) -> None:
361
- self._deployments = {d.model_label: d for d in deployments}
362
- self._poll_interval_s = poll_interval_s
363
- self._lock = threading.Lock()
364
- self._latest: dict[str, dict[str, float]] = {}
365
-
366
- def get_latest(self) -> dict[str, dict[str, float]]:
367
- """Return the latest metrics snapshot per model (thread-safe)."""
368
- with self._lock:
369
- return dict(self._latest)
370
-
371
- async def run(self, stop_event: threading.Event) -> None:
372
- """Poll loop. Call as an asyncio task."""
373
- async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5.0)) as session:
374
- while not stop_event.is_set():
375
- for label, dep in self._deployments.items():
376
- url = f"{dep.vllm_base_url}/metrics"
377
- try:
378
- async with session.get(url) as resp:
379
- if resp.status == 200:
380
- text = await resp.text()
381
- metrics = _parse_prometheus_text(text)
382
- with self._lock:
383
- self._latest[label] = metrics
384
- except Exception:
385
- logger.debug("Prometheus poll failed for %s", label, exc_info=True)
386
- await asyncio.sleep(self._poll_interval_s)
387
-
388
-
389
- class _LoadGenerator:
390
- """Background load generator that saturates vLLM servers and measures ITL.
391
-
392
- Runs a daemon thread with an asyncio event loop. For each model, a
393
- semaphore-gated producer loop cycles through pre-built request dicts
394
- endlessly. The semaphore size is `2 * max(feasible_batch_sizes)`,
395
- ensuring the vLLM queue never drains even at the largest batch size
396
- the OFO controller can set. Per-token inter-token latency (ITL) is
397
- measured from SSE chunk arrival times using `usage.completion_tokens`
398
- increments; first-token latency (TTFT) is excluded from ITL samples.
399
- """
400
-
401
- def __init__(
402
- self,
403
- deployments: Sequence[VLLMDeployment],
404
- *,
405
- request_store: RequestStore | None = None,
406
- max_output_tokens: int = 512,
407
- itl_window_s: float = 1.0,
408
- prometheus_poller: _PrometheusPoller | None = None,
409
- ) -> None:
410
- self._deployments = {d.model_label: d for d in deployments}
411
- self._requests: dict[str, list[dict]] = {}
412
- if request_store is not None:
413
- self._requests = dict(request_store.requests_by_model)
414
- self._max_output_tokens = max_output_tokens
415
- self._itl_window_s = itl_window_s
416
- self._prometheus = prometheus_poller
417
-
418
- self._lock = threading.Lock()
419
- self._itl_samples: dict[str, collections.deque[tuple[float, float]]] = {}
420
- for d in deployments:
421
- self._itl_samples[d.model_label] = collections.deque()
422
-
423
- self._thread: threading.Thread | None = None
424
- self._stop_event = threading.Event()
425
- self._loop: asyncio.AbstractEventLoop | None = None
426
-
427
- def start(self) -> None:
428
- if self._thread is not None:
429
- raise RuntimeError("LoadGenerator already started")
430
- self._stop_event.clear()
431
- self._thread = threading.Thread(
432
- target=self._run_thread,
433
- name="load-generator",
434
- daemon=True,
435
- )
436
- self._thread.start()
437
-
438
- def stop(self) -> None:
439
- self._stop_event.set()
440
- if self._loop is not None:
441
- self._loop.call_soon_threadsafe(self._loop.stop)
442
- if self._thread is not None:
443
- self._thread.join(timeout=10.0)
444
- self._thread = None
445
-
446
- def get_observed_itl(self, model_label: str, window_s: float | None = None) -> float:
447
- """Return the windowed-average ITL for *model_label*, or NaN."""
448
- if window_s is None:
449
- window_s = self._itl_window_s
450
- cutoff = time.monotonic() - window_s
451
- with self._lock:
452
- samples = self._itl_samples.get(model_label)
453
- if not samples:
454
- return float("nan")
455
- recent = [itl for ts, itl in samples if ts >= cutoff]
456
- if not recent:
457
- return float("nan")
458
- return sum(recent) / len(recent)
459
-
460
- def _run_thread(self) -> None:
461
- self._loop = asyncio.new_event_loop()
462
- asyncio.set_event_loop(self._loop)
463
- try:
464
- self._loop.run_until_complete(self._run_async())
465
- except Exception:
466
- if not self._stop_event.is_set():
467
- logger.exception("LoadGenerator thread crashed")
468
- finally:
469
- self._loop.close()
470
- self._loop = None
471
-
472
- async def _run_async(self) -> None:
473
- tasks: list[asyncio.Task] = []
474
-
475
- for label, dep in self._deployments.items():
476
- tasks.append(asyncio.create_task(self._model_producer(label, dep)))
477
-
478
- if self._prometheus is not None:
479
- tasks.append(asyncio.create_task(self._prometheus.run(self._stop_event)))
480
-
481
- while not self._stop_event.is_set():
482
- await asyncio.sleep(0.1)
483
-
484
- for t in tasks:
485
- t.cancel()
486
- await asyncio.gather(*tasks, return_exceptions=True)
487
-
488
- async def _model_producer(self, label: str, dep: VLLMDeployment) -> None:
489
- """Semaphore-gated loop that continuously submits requests for one model.
490
-
491
- Cycles through the JSONL request list endlessly. The semaphore
492
- limits in-flight requests to `2 * max(feasible_batch_sizes)`,
493
- ensuring the vLLM server always has a non-empty queue.
494
- """
495
- max_batch = max(dep.spec.feasible_batch_sizes)
496
- sem = asyncio.Semaphore(2 * max_batch)
497
- requests = self._requests.get(label, [])
498
- req_idx = 0
499
- active: set[asyncio.Task[None]] = set()
500
-
501
- connector = aiohttp.TCPConnector(limit=0, ssl=False)
502
- async with aiohttp.ClientSession(
503
- timeout=aiohttp.ClientTimeout(total=300.0),
504
- connector=connector,
505
- ) as session:
506
- while not self._stop_event.is_set():
507
- await sem.acquire()
508
- if self._stop_event.is_set():
509
- break
510
- if requests:
511
- request_dict = requests[req_idx % len(requests)]
512
- req_idx += 1
513
- else:
514
- request_dict = self._default_request(dep)
515
- task = asyncio.create_task(self._single_request(label, dep, request_dict, session, sem))
516
- active.add(task)
517
- task.add_done_callback(active.discard)
518
-
519
- def _default_request(self, dep: VLLMDeployment) -> dict:
520
- """Build a minimal fallback request dict."""
521
- return {
522
- "model": dep.spec.model_id,
523
- "messages": [{"role": "user", "content": "Hello, how are you?"}],
524
- "max_completion_tokens": self._max_output_tokens,
525
- }
526
-
527
- async def _single_request(
528
- self,
529
- label: str,
530
- dep: VLLMDeployment,
531
- request_dict: dict,
532
- session: aiohttp.ClientSession,
533
- sem: asyncio.Semaphore,
534
- ) -> None:
535
- """Send one streaming chat-completion request and measure decoding ITL.
536
-
537
- Uses `usage.completion_tokens` increments to correctly handle
538
- multi-token bundles. First-token samples (TTFT) are skipped;
539
- only decoding-phase ITL is recorded.
540
- """
541
- try:
542
- url = f"{dep.vllm_base_url}/v1/chat/completions"
543
- body = dict(request_dict)
544
- body["stream"] = True
545
- body["stream_options"] = {"include_usage": True, "continuous_usage_stats": True}
546
- if "max_tokens" in body and "max_completion_tokens" not in body:
547
- body["max_completion_tokens"] = body.pop("max_tokens")
548
-
549
- current_completion_tokens = 0
550
- most_recent_timestamp = time.perf_counter()
551
- ttft_recorded = False
552
-
553
- async with session.post(url, json=body) as response:
554
- if response.status != 200:
555
- return
556
- async for chunk_bytes in response.content:
557
- if self._stop_event.is_set():
558
- return
559
- chunk_bytes = chunk_bytes.strip()
560
- if not chunk_bytes:
561
- continue
562
-
563
- chunk_str = chunk_bytes.decode("utf-8")
564
-
565
- if chunk_str.startswith(":"):
566
- continue
567
-
568
- data_str = chunk_str.removeprefix("data: ")
569
- if data_str == "[DONE]":
570
- break
571
-
572
- try:
573
- data = json.loads(data_str)
574
- except json.JSONDecodeError:
575
- continue
576
-
577
- usage = data.get("usage")
578
- completion_tokens = usage and usage.get("completion_tokens")
579
- if not completion_tokens:
580
- continue
581
-
582
- timestamp = time.perf_counter()
583
-
584
- if not ttft_recorded:
585
- ttft_recorded = True
586
- current_completion_tokens = completion_tokens
587
- else:
588
- itl = timestamp - most_recent_timestamp
589
- inc = completion_tokens - current_completion_tokens
590
- current_completion_tokens = completion_tokens
591
-
592
- now_mono = time.monotonic()
593
- with self._lock:
594
- self._itl_samples[label].append((now_mono, itl))
595
- for _ in range(max(inc - 1, 0)):
596
- self._itl_samples[label].append((now_mono, 0.0))
597
-
598
- most_recent_timestamp = timestamp
599
-
600
- except Exception:
601
- if not self._stop_event.is_set():
602
- logger.debug("Request to %s failed for %s", dep.vllm_base_url, label, exc_info=True)
603
- finally:
604
- sem.release()
605
-
606
-
607
- class _RollingPowerBuffer:
608
- """Per-model rolling buffer of (timestamp, per_gpu_watts) readings.
609
-
610
- Provides `sample_servers()` to look up historical per-GPU power at
611
- different time offsets for each virtual server, enabling temporal
612
- staggering of batch-size-change transients.
613
- """
614
-
615
- def __init__(self, model_labels: Sequence[str], max_samples: int = 10000) -> None:
616
- self._buffers: dict[str, collections.deque[tuple[float, float]]] = {
617
- label: collections.deque(maxlen=max_samples) for label in model_labels
618
- }
619
-
620
- def append(self, label: str, timestamp: float, per_gpu_w: float) -> None:
621
- """Feed a new per-GPU power reading for a model."""
622
- self._buffers[label].append((timestamp, per_gpu_w))
623
-
624
- def sample_servers(
625
- self,
626
- label: str,
627
- now: float,
628
- stagger_offsets: np.ndarray,
629
- ) -> np.ndarray:
630
- """Look up per-GPU power at `now - offset[i]` for each virtual server.
631
-
632
- Args:
633
- label: Model label.
634
- now: Current wall-clock time (monotonic).
635
- stagger_offsets: Per-server time offsets (seconds), shape `(N,)`.
636
-
637
- Returns:
638
- Array of shape `(N,)` with per-GPU power for each server.
639
- """
640
- buf = self._buffers[label]
641
- n = len(stagger_offsets)
642
- result = np.zeros(n, dtype=float)
643
- if not buf:
644
- return result
645
- for i in range(n):
646
- result[i] = self._lookup(buf, now - stagger_offsets[i])
647
- return result
648
-
649
- def clear(self) -> None:
650
- """Clear all buffers."""
651
- for buf in self._buffers.values():
652
- buf.clear()
653
-
654
- @staticmethod
655
- def _lookup(buf: collections.deque[tuple[float, float]], target_t: float) -> float:
656
- """Find the power reading at or just before `target_t`."""
657
- if not buf:
658
- return 0.0
659
- if target_t <= buf[0][0]:
660
- return buf[0][1]
661
- if target_t >= buf[-1][0]:
662
- return buf[-1][1]
663
- for i in range(len(buf) - 1, -1, -1):
664
- if buf[i][0] <= target_t:
665
- return buf[i][1]
666
- return buf[0][1]
667
-
668
-
669
- class OnlineDatacenter(LLMBatchSizeControlledDatacenter[OnlineDatacenterState]):
670
- """Live GPU datacenter backend with power augmentation.
671
-
672
- Dispatches inference load to vLLM servers, streams GPU power from
673
- zeusd, measures ITL from streaming responses, and augments power
674
- readings to datacenter scale using the shared
675
- [`InferencePowerAugmenter`][openg2g.datacenter.workloads.inference.InferencePowerAugmenter]
676
- pipeline (same as
677
- [`OfflineDatacenter`][openg2g.datacenter.offline.OfflineDatacenter]).
678
-
679
- Call [`start`][.start] before the first [`step`][.step] and
680
- [`stop`][.stop] after the simulation loop finishes.
681
-
682
- `PowerStreamingClient` is constructed internally from the GPU
683
- endpoints declared in each deployment. Health checks are always
684
- performed during [`start`][.start].
685
-
686
- Args:
687
- datacenter: Facility configuration (GPUs per server, base load).
688
- deployments: Model deployments with physical hardware mapping.
689
- dt_s: Simulation timestep (seconds).
690
- seed: Random seed for layout generation and noise.
691
- power_augmentation: Per-server amplitude scaling and noise
692
- settings.
693
- inference_ramps: Inference server ramp event(s). `None` keeps
694
- all servers active.
695
- live_server: Configuration for interacting with live vLLM
696
- servers. Request data is loaded from
697
- `LiveServerConfig.requests_dir`.
698
- """
699
-
700
- def __init__(
701
- self,
702
- datacenter: DatacenterConfig,
703
- deployments: Sequence[VLLMDeployment],
704
- *,
705
- dt_s: Fraction = Fraction(1, 10),
706
- seed: int = 0,
707
- power_augmentation: PowerAugmentationConfig | None = None,
708
- inference_ramps: InferenceRampSchedule | None = None,
709
- live_server: LiveServerConfig | None = None,
710
- ) -> None:
711
- super().__init__()
712
- if power_augmentation is None:
713
- power_augmentation = PowerAugmentationConfig()
714
- if live_server is None:
715
- live_server = LiveServerConfig()
716
- self._dt_s = dt_s
717
- self._seed = int(seed)
718
- self._deployments = list(deployments)
719
- self._deployment_map = {d.model_label: d for d in deployments}
720
- self._datacenter_config = datacenter
721
- self._power_augmentation = power_augmentation
722
- self._live_server_config = live_server
723
-
724
- self._base_W_per_phase = float(datacenter.base_kw_per_phase) * 1e3
725
- self._inference_ramp_schedule = inference_ramps if inference_ramps is not None else InferenceRampSchedule()
726
-
727
- servers_by_key: dict[str, ZeusdConfig] = {}
728
- gpu_indices_by_key: dict[str, list[int]] = {}
729
- for d in self._deployments:
730
- for ep in d.gpu_endpoints:
731
- key = ep.endpoint_key
732
- if key not in gpu_indices_by_key:
733
- gpu_indices_by_key[key] = []
734
- for idx in ep.gpu_indices:
735
- if idx not in gpu_indices_by_key[key]:
736
- gpu_indices_by_key[key].append(idx)
737
- servers_by_key[key] = ZeusdConfig.tcp(
738
- ep.host,
739
- ep.port,
740
- gpu_indices=gpu_indices_by_key[key],
741
- cpu_indices=[],
742
- )
743
- self._power_client = PowerStreamingClient(servers=list(servers_by_key.values()))
744
-
745
- self._prometheus = (
746
- _PrometheusPoller(
747
- deployments,
748
- poll_interval_s=live_server.prometheus_poll_interval_s,
749
- )
750
- if live_server.prometheus_poll_interval_s > 0
751
- else None
752
- )
753
-
754
- self._request_store = RequestStore.load(live_server.requests_dir) if live_server.requests_dir else None
755
- self._load_gen = _LoadGenerator(
756
- deployments,
757
- request_store=self._request_store,
758
- max_output_tokens=live_server.max_output_tokens,
759
- itl_window_s=live_server.itl_window_s,
760
- prometheus_poller=self._prometheus,
761
- )
762
-
763
- self._layout_rng = np.random.default_rng(self._seed)
764
- self._layouts: dict[str, ServerLayout] = {}
765
- self._policies: dict[str, ActivationPolicy] = {}
766
- self._build_all_layouts()
767
- self._inference_augmenter = InferencePowerAugmenter(
768
- layouts=self._layouts,
769
- policies=self._policies,
770
- seed=self._seed + 12345,
771
- )
772
- self._rolling_buffer = _RollingPowerBuffer(
773
- [d.model_label for d in deployments],
774
- max_samples=max(int(STAGGER_BUFFER_S * 100), 1000),
775
- )
776
-
777
- self._started = False
778
-
779
- logger.info(
780
- "OnlineDatacenter: %d deployments, dt=%s s",
781
- len(self._deployments),
782
- dt_s,
783
- )
784
- for d in deployments:
785
- layout = self._layouts.get(d.model_label)
786
- n_servers = layout.num_servers if layout else 0
787
- logger.info(
788
- " %s: %d real GPUs, %d simulated replicas (%.0fx augmentation), %d virtual servers, vllm=%s",
789
- d.model_label,
790
- d.num_real_gpus,
791
- d.spec.num_replicas,
792
- d.augmentation_factor,
793
- n_servers,
794
- d.vllm_base_url,
795
- )
796
-
797
- def _build_all_layouts(self) -> None:
798
- """Build ServerLayout and activation policies for each deployed model.
799
-
800
- The RNG invocation order per model must be: phase shuffle,
801
- priority shuffle, stagger offsets, amplitude scales. We
802
- interleave policy construction between the phase shuffle
803
- and stagger/amplitude draws to preserve this ordering.
804
- """
805
- schedule = self._inference_ramp_schedule
806
- gpus_per_server = self._datacenter_config.gpus_per_server
807
- rng = self._layout_rng
808
- amp_lo, amp_hi = self._power_augmentation.amplitude_scale_range
809
- noise_fraction = self._power_augmentation.noise_fraction
810
- stagger_s = float(STAGGER_BUFFER_S)
811
-
812
- for d in self._deployments:
813
- spec = d.spec
814
- if spec.num_replicas > 0:
815
- num_servers = math.ceil(spec.num_replicas * spec.gpus_per_replica / gpus_per_server)
816
-
817
- # Phase shuffle (consumes RNG)
818
- sA, sB, sC = split_integer_evenly(num_servers, 3)
819
- phase_list = np.asarray(([0] * sA) + ([1] * sB) + ([2] * sC), dtype=int)
820
- rng.shuffle(phase_list)
821
-
822
- # Priority shuffle (consumes RNG) — must happen here
823
- self._policies[d.model_label] = RampActivationPolicy(
824
- schedule,
825
- num_servers,
826
- rng,
827
- )
828
-
829
- # Stagger offsets (consumes RNG) — float for online
830
- stagger_offsets = rng.uniform(0.0, max(stagger_s, 1e-9), size=num_servers)
831
-
832
- # Amplitude scales (consumes RNG)
833
- amplitude_scales = rng.uniform(amp_lo, amp_hi, size=num_servers)
834
-
835
- total_gpus = spec.num_replicas * spec.gpus_per_replica
836
- gpus_per_server_list = np.full(num_servers, gpus_per_server, dtype=int)
837
- tail = total_gpus - (num_servers - 1) * gpus_per_server
838
- gpus_per_server_list[-1] = int(tail) if tail > 0 else gpus_per_server
839
-
840
- self._layouts[d.model_label] = ServerLayout(
841
- num_servers=num_servers,
842
- total_gpus=total_gpus,
843
- gpus_per_replica=spec.gpus_per_replica,
844
- gpus_per_server_list=gpus_per_server_list,
845
- phase_list=phase_list,
846
- stagger_offsets=stagger_offsets,
847
- amplitude_scales=amplitude_scales,
848
- noise_fraction=noise_fraction,
849
- )
850
-
851
- @property
852
- def dt_s(self) -> Fraction:
853
- return self._dt_s
854
-
855
- @property
856
- def phase_share_by_model(self) -> dict[str, np.ndarray]:
857
- """Per-model phase share vectors derived from server layout."""
858
- shares: dict[str, np.ndarray] = {}
859
- for label, layout in self._layouts.items():
860
- counts = np.bincount(layout.phase_list, minlength=3).astype(float)
861
- total = counts.sum()
862
- if total > 0:
863
- shares[label] = counts / total
864
- else:
865
- shares[label] = np.array([1 / 3, 1 / 3, 1 / 3], dtype=float)
866
- return shares
867
-
868
- def reset(self) -> None:
869
- if self._started:
870
- self._load_gen.stop()
871
- self._load_gen = _LoadGenerator(
872
- self._deployments,
873
- request_store=self._request_store,
874
- max_output_tokens=self._live_server_config.max_output_tokens,
875
- itl_window_s=self._live_server_config.itl_window_s,
876
- prometheus_poller=self._prometheus,
877
- )
878
- self._layout_rng = np.random.default_rng(self._seed)
879
- self._layouts = {}
880
- self._policies = {}
881
- self._build_all_layouts()
882
- self._inference_augmenter = InferencePowerAugmenter(
883
- layouts=self._layouts,
884
- policies=self._policies,
885
- seed=self._seed + 12345,
886
- )
887
- self._rolling_buffer.clear()
888
- for d in self._deployments:
889
- d.batch_size = d.spec.initial_batch_size
890
- self._started = False
891
-
892
- def start(self) -> None:
893
- """Start load generation, warm up servers, and fill the power buffer.
894
-
895
- Sequence:
896
- 1. Run health checks on all vLLM servers and zeusd instances.
897
- 2. Wait for at least one power reading per endpoint (10 s timeout).
898
- 3. Set initial batch sizes on all vLLM servers.
899
- 4. Start load generation threads.
900
- 5. Warm up: poll power into the rolling buffer while waiting for
901
- each model's `num_requests_running` to reach 95% of its
902
- `initial_batch_size`. Fails after 60 s if any model does not
903
- saturate.
904
- """
905
- if self._started:
906
- raise RuntimeError("OnlineDatacenter already started")
907
-
908
- logger.info("Starting OnlineDatacenter with %d deployments", len(self._deployments))
909
-
910
- # 1. Health checks
911
- logger.info("Running health checks...")
912
- for d in self._deployments:
913
- _check_vllm_health(d.vllm_base_url)
914
- _check_vllm_model(d.vllm_base_url, d.spec.model_id)
915
- for ep in d.gpu_endpoints:
916
- _check_zeusd_health(ep.host, ep.port)
917
- logger.info("All health checks passed")
918
-
919
- # 2. Wait for power readings from all endpoints
920
- all_endpoints: set[str] = set()
921
- for d in self._deployments:
922
- for ep in d.gpu_endpoints:
923
- all_endpoints.add(ep.endpoint_key)
924
-
925
- deadline = time.monotonic() + 10.0
926
- while time.monotonic() < deadline:
927
- readings = self._power_client.get_power()
928
- if all_endpoints.issubset(readings.keys()):
929
- logger.info("Power readings received from all %d endpoints", len(all_endpoints))
930
- break
931
- time.sleep(0.5)
932
- else:
933
- connected = set(self._power_client.get_power().keys())
934
- missing = all_endpoints - connected
935
- logger.warning("Timed out waiting for power readings from: %s", missing)
936
-
937
- # 3. Set initial batch sizes on vLLM servers
938
- for d in self._deployments:
939
- d.set_batch_size(d.spec.initial_batch_size)
940
-
941
- # 4. Start load generation (and Prometheus poller)
942
- self._load_gen.start()
943
- logger.info("LoadGenerator started")
944
-
945
- # 5. Warm up: fill power buffer + wait for server saturation
946
- self._warmup()
947
-
948
- self._started = True
949
- logger.info("OnlineDatacenter ready")
950
-
951
- def _poll_power_into_buffer(self) -> tuple[float, dict[str, float]]:
952
- """Read GPU power from all endpoints and feed the rolling buffer.
953
-
954
- Returns:
955
- Tuple of (monotonic timestamp, per-model average per-GPU watts).
956
- """
957
- now = time.monotonic()
958
- raw_power = self._power_client.get_power()
959
- per_gpu_by_model: dict[str, float] = {}
960
- for d in self._deployments:
961
- total_w = 0.0
962
- n_gpus = 0
963
- for ep in d.gpu_endpoints:
964
- pr = raw_power.get(ep.endpoint_key)
965
- if pr is None:
966
- continue
967
- for idx in ep.gpu_indices:
968
- if idx in pr.gpu_power_w:
969
- total_w += pr.gpu_power_w[idx]
970
- n_gpus += 1
971
- per_gpu_w = total_w / n_gpus if n_gpus > 0 else 0.0
972
- self._rolling_buffer.append(d.model_label, now, per_gpu_w)
973
- per_gpu_by_model[d.model_label] = per_gpu_w
974
- return now, per_gpu_by_model
975
-
976
- def _warmup(
977
- self,
978
- timeout_s: float = 60.0,
979
- saturation_threshold: float = 0.95,
980
- poll_interval_s: float = 0.1,
981
- ) -> None:
982
- """Fill the rolling power buffer and wait for vLLM server saturation.
983
-
984
- Actively polls GPU power to fill the rolling buffer while monitoring
985
- Prometheus `num_requests_running` to verify each model has reached
986
- `saturation_threshold` of its `initial_batch_size`.
987
-
988
- Completion requires both conditions for every model:
989
- 1. `num_requests_running >= saturation_threshold * initial_batch_size`
990
- 2. At least `stagger_buffer_s` has elapsed since that model first
991
- reached saturation (so the buffer contains a full stagger
992
- window of steady-state power data).
993
-
994
- Args:
995
- timeout_s: Maximum warmup duration in seconds.
996
- saturation_threshold: Fraction of `initial_batch_size` that
997
- `num_requests_running` must reach (0.0-1.0).
998
- poll_interval_s: Seconds between power polls.
999
-
1000
- Raises:
1001
- RuntimeError: If any model fails to saturate within `timeout_s`.
1002
- Includes the `num_requests_running` trajectory for failed
1003
- models.
1004
- """
1005
- stagger_s = STAGGER_BUFFER_S
1006
- logger.info(
1007
- "Warming up: waiting for server saturation (%.0f%% of initial_batch_size) "
1008
- "+ %.1f s buffer fill per model...",
1009
- saturation_threshold * 100,
1010
- stagger_s,
1011
- )
1012
-
1013
- warmup_start = time.monotonic()
1014
- deadline = warmup_start + timeout_s
1015
- last_log = warmup_start
1016
-
1017
- trajectory: dict[str, list[tuple[float, float]]] = {d.model_label: [] for d in self._deployments}
1018
- saturation_time: dict[str, float | None] = {d.model_label: None for d in self._deployments}
1019
-
1020
- while time.monotonic() < deadline:
1021
- now = time.monotonic()
1022
- elapsed = now - warmup_start
1023
-
1024
- self._poll_power_into_buffer()
1025
-
1026
- all_ready = True
1027
- if self._prometheus is not None:
1028
- prom = self._prometheus.get_latest()
1029
- for d in self._deployments:
1030
- label = d.model_label
1031
- running = prom.get(label, {}).get("num_requests_running", 0.0)
1032
- trajectory[label].append((elapsed, running))
1033
- target = d.spec.initial_batch_size * saturation_threshold
1034
-
1035
- if running >= target and saturation_time[label] is None:
1036
- saturation_time[label] = now
1037
- logger.info(
1038
- " %s saturated at t=%.1f s (num_requests_running=%.0f)",
1039
- label,
1040
- elapsed,
1041
- running,
1042
- )
1043
-
1044
- sat_t = saturation_time[label]
1045
- if sat_t is None or (now - sat_t) < stagger_s:
1046
- all_ready = False
1047
- else:
1048
- logger.warning(
1049
- "Prometheus polling is disabled; cannot verify server saturation. "
1050
- "Waiting %.1f s for power buffer only.",
1051
- stagger_s,
1052
- )
1053
- if elapsed < stagger_s:
1054
- all_ready = False
1055
-
1056
- if all_ready:
1057
- logger.info("Warmup complete in %.1f s", elapsed)
1058
- return
1059
-
1060
- if now - last_log >= 10.0:
1061
- last_log = now
1062
- if self._prometheus is not None:
1063
- prom = self._prometheus.get_latest()
1064
- for d in self._deployments:
1065
- label = d.model_label
1066
- running = prom.get(label, {}).get("num_requests_running", 0.0)
1067
- target = d.spec.initial_batch_size
1068
- sat_t = saturation_time[label]
1069
- buf_s = (now - sat_t) if sat_t is not None else 0.0
1070
- logger.info(
1071
- " Warmup %s: num_requests_running=%.0f / %d (%.0f%%), buffer=%.1f / %.1f s",
1072
- label,
1073
- running,
1074
- target,
1075
- running / max(target, 1) * 100,
1076
- buf_s,
1077
- stagger_s,
1078
- )
1079
-
1080
- time.sleep(poll_interval_s)
1081
-
1082
- if self._prometheus is None:
1083
- raise RuntimeError(
1084
- f"Warmup timed out after {timeout_s:.0f} s waiting for power buffer to fill ({stagger_s:.1f} s)"
1085
- )
1086
-
1087
- prom = self._prometheus.get_latest()
1088
- failed: list[str] = []
1089
- for d in self._deployments:
1090
- label = d.model_label
1091
- running = prom.get(label, {}).get("num_requests_running", 0.0)
1092
- sat_t = saturation_time[label]
1093
- not_saturated = running < d.spec.initial_batch_size * saturation_threshold
1094
- not_buffered = sat_t is None or (time.monotonic() - sat_t) < stagger_s
1095
- if not_saturated or not_buffered:
1096
- failed.append(label)
1097
-
1098
- parts = [
1099
- f"Warmup timed out after {timeout_s:.0f} s. "
1100
- f"Models that failed to reach {saturation_threshold:.0%} of initial_batch_size:",
1101
- ]
1102
- for label in failed:
1103
- target = self._deployment_map[label].spec.initial_batch_size
1104
- traj = trajectory[label]
1105
- final = traj[-1][1] if traj else 0.0
1106
- parts.append(f" {label} (target: {target}, reached: {final:.0f}):")
1107
- step = max(1, int(5.0 / poll_interval_s))
1108
- samples = traj[::step]
1109
- if traj and (not samples or samples[-1] is not traj[-1]):
1110
- samples.append(traj[-1])
1111
- entries = [f"t={t:.0f}s: {r:.0f}" for t, r in samples]
1112
- parts.append(" " + ", ".join(entries))
1113
- raise RuntimeError("\n".join(parts))
1114
-
1115
- def stop(self) -> None:
1116
- """Stop load generation and power streaming."""
1117
- self._load_gen.stop()
1118
- self._power_client.stop()
1119
- self._started = False
1120
- logger.info("OnlineDatacenter stopped")
1121
-
1122
- def step(self, clock: SimulationClock, events: EventEmitter) -> OnlineDatacenterState:
1123
- """Read live power, augment to datacenter scale, and return state."""
1124
- now, per_gpu_w_by_model = self._poll_power_into_buffer()
1125
-
1126
- measured_power_by_model: dict[str, float] = {}
1127
- augmentation_factor_by_model: dict[str, float] = {}
1128
- for d in self._deployments:
1129
- label = d.model_label
1130
- measured_power_by_model[label] = per_gpu_w_by_model.get(label, 0.0) * d.num_real_gpus
1131
- augmentation_factor_by_model[label] = d.augmentation_factor
1132
-
1133
- per_gpu_by_model: dict[str, np.ndarray] = {}
1134
- for d in self._deployments:
1135
- label = d.model_label
1136
- if label not in self._layouts:
1137
- continue
1138
- layout = self._layouts[label]
1139
- per_gpu_by_model[label] = self._rolling_buffer.sample_servers(label, now, layout.stagger_offsets)
1140
-
1141
- inference_aug = self._inference_augmenter.augment(per_gpu_by_model, clock.time_s)
1142
-
1143
- measured_total = sum(measured_power_by_model.values())
1144
- measured_per_phase = measured_total / 3.0
1145
-
1146
- observed_itl: dict[str, float] = {
1147
- d.model_label: self._load_gen.get_observed_itl(d.model_label) for d in self._deployments
1148
- }
1149
-
1150
- prometheus_metrics: dict[str, dict[str, float]] = {}
1151
- if self._prometheus is not None:
1152
- prometheus_metrics = self._prometheus.get_latest()
1153
-
1154
- state = OnlineDatacenterState(
1155
- time_s=clock.time_s,
1156
- power_w=ThreePhase(
1157
- a=self._base_W_per_phase + inference_aug.power_w.a,
1158
- b=self._base_W_per_phase + inference_aug.power_w.b,
1159
- c=self._base_W_per_phase + inference_aug.power_w.c,
1160
- ),
1161
- batch_size_by_model={d.model_label: d.batch_size for d in self._deployments},
1162
- active_replicas_by_model=inference_aug.active_replicas_by_model,
1163
- observed_itl_s_by_model=observed_itl,
1164
- measured_power_w=ThreePhase(
1165
- a=measured_per_phase + self._base_W_per_phase,
1166
- b=measured_per_phase + self._base_W_per_phase,
1167
- c=measured_per_phase + self._base_W_per_phase,
1168
- ),
1169
- measured_power_w_by_model=measured_power_by_model,
1170
- augmented_power_w_by_model=inference_aug.power_by_model_w,
1171
- augmentation_factor_by_model=augmentation_factor_by_model,
1172
- prometheus_metrics_by_model=prometheus_metrics,
1173
- )
1174
- return state
1175
-
1176
- @functools.singledispatchmethod
1177
- def apply_control(self, command: DatacenterCommand, events: EventEmitter) -> None:
1178
- """Apply a control command. Dispatches on command type."""
1179
- raise TypeError(f"OnlineDatacenter does not support {type(command).__name__}")
1180
-
1181
- @apply_control.register
1182
- def apply_control_set_batch_size(self, command: SetBatchSize, events: EventEmitter) -> None:
1183
- """Apply batch size command by sending HTTP requests to vLLM servers."""
1184
- for label, b in command.batch_size_by_model.items():
1185
- label = str(label)
1186
- b_int = int(b)
1187
- if b_int <= 0:
1188
- raise ValueError(f"Batch size must be positive for model {label!r}, got {b_int}.")
1189
- dep = self._deployment_map.get(label)
1190
- if dep is not None:
1191
- dep.set_batch_size(b_int, ramp_up_rate=command.ramp_up_rate_by_model.get(label, 0.0))
1192
-
1193
- events.emit(
1194
- "datacenter.batch_size.updated",
1195
- {"batch_size_by_model": {d.model_label: d.batch_size for d in self._deployments}},
1196
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/workloads/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- """Datacenter workloads.
2
-
3
- LLM inference workloads and training workloads.
4
- """
 
 
 
 
 
openg2g/datacenter/workloads/inference.py DELETED
@@ -1,1363 +0,0 @@
1
- """Inference workload: power traces, templates, ITL fits, and augmentation."""
2
-
3
- from __future__ import annotations
4
-
5
- import json
6
- import logging
7
- from collections.abc import Sequence
8
- from dataclasses import dataclass, field
9
- from fractions import Fraction
10
- from pathlib import Path
11
- from typing import Any, cast
12
-
13
- import numpy as np
14
- import pandas as pd
15
- from mlenergy_data.modeling import ITLMixtureModel
16
- from mlenergy_data.records import LLMRuns
17
- from pydantic import BaseModel, ConfigDict
18
-
19
- import openg2g
20
- from openg2g.common import ThreePhase
21
- from openg2g.datacenter.config import InferenceModelSpec
22
- from openg2g.datacenter.layout import ActivationPolicy, ServerLayout
23
-
24
- logger = logging.getLogger(__name__)
25
-
26
-
27
- class MLEnergySource(BaseModel):
28
- """Per-model ML.ENERGY benchmark data extraction settings.
29
-
30
- Attributes:
31
- model_label: Simulation label for the model.
32
- task: Benchmark task name (e.g. `"lm-arena-chat"`, `"gpqa"`).
33
- gpu: GPU model name (e.g. `"H100"`).
34
- batch_sizes: Batch sizes to extract from the benchmark data.
35
- fit_exclude_batch_sizes: Batch sizes to exclude from logistic
36
- curve fitting (but still included in trace extraction).
37
- """
38
-
39
- model_config = ConfigDict(frozen=True)
40
-
41
- model_label: str
42
- task: str
43
- gpu: str
44
- batch_sizes: tuple[int, ...]
45
- fit_exclude_batch_sizes: tuple[int, ...] = ()
46
-
47
-
48
- @dataclass(frozen=True)
49
- class InferenceTrace:
50
- """A single power trace measurement.
51
-
52
- Attributes:
53
- t_s: Time vector (seconds), monotonically increasing.
54
- power_w: Total power vector (watts) across all measured GPUs,
55
- same length as `t_s`.
56
- measured_gpus: Number of GPUs used in the measurement.
57
- """
58
-
59
- t_s: np.ndarray
60
- power_w: np.ndarray
61
- measured_gpus: int
62
-
63
- def __post_init__(self) -> None:
64
- if len(self.t_s) != len(self.power_w):
65
- raise ValueError(f"t_s and power_w must have the same length, got {len(self.t_s)} and {len(self.power_w)}")
66
- if len(self.t_s) < 5:
67
- raise ValueError("Trace too short (need at least 5 samples).")
68
- if self.measured_gpus < 1:
69
- raise ValueError(f"measured_gpus must be >= 1, got {self.measured_gpus}")
70
-
71
-
72
- def _build_per_gpu_power_template(
73
- trace: InferenceTrace,
74
- *,
75
- dt_s: Fraction | float,
76
- duration_s: Fraction | float,
77
- steady_skip_s: float = 0.0,
78
- ) -> np.ndarray:
79
- """Build a per-GPU power template over [0, duration_s] by periodic repetition.
80
-
81
- Args:
82
- trace: Source power trace (total power across measured GPUs).
83
- dt_s: Simulation timestep in seconds.
84
- duration_s: Total simulation duration in seconds.
85
- steady_skip_s: Skip this many seconds from the start of the trace
86
- to avoid warm-up transients.
87
-
88
- Returns:
89
- 1-D array of per-GPU power values at each simulation timestep.
90
- """
91
- trace_t = np.asarray(trace.t_s, float)
92
- trace_p_total = np.asarray(trace.power_w, float)
93
-
94
- mg = max(trace.measured_gpus, 1)
95
- p_per_gpu = trace_p_total / mg
96
- p_per_gpu = np.clip(p_per_gpu, 0.0, None)
97
-
98
- if steady_skip_s > 0.0:
99
- idx0 = np.searchsorted(trace_t, trace_t[0] + float(steady_skip_s))
100
- if idx0 < trace_t.size - 5:
101
- trace_t = trace_t[idx0:] - trace_t[idx0]
102
- p_per_gpu = p_per_gpu[idx0:]
103
-
104
- trace_t = trace_t - trace_t[0]
105
- period = float(trace_t[-1] - trace_t[0])
106
- if period <= 0:
107
- raise ValueError("Non-positive trace duration.")
108
-
109
- n_steps = int(np.ceil(float(duration_s) / float(dt_s))) + 1
110
- t_grid = np.arange(n_steps, dtype=float) * float(dt_s)
111
- t_mod = np.mod(t_grid, period)
112
-
113
- template = np.interp(t_mod, trace_t, p_per_gpu, left=p_per_gpu[0], right=p_per_gpu[-1])
114
- return np.clip(template, 0.0, None)
115
-
116
-
117
- class ITLFitStore:
118
- """Per-model, per-batch-size ITL mixture distributions.
119
-
120
- Indexed by `(model_label, batch_size)`. Provides:
121
-
122
- - [`load`][.load]: load fits from a CSV produced by the data pipeline
123
- - [`distributions`][.distributions]: access as a nested dict
124
- - [`sample_avg`][.sample_avg]: sample a fleet-average ITL value
125
-
126
- Attributes:
127
- COL_MODEL_LABEL: Column name for model label in the CSV.
128
- COL_BATCH_SIZE: Column name for batch size in the CSV.
129
- """
130
-
131
- COL_MODEL_LABEL = "model_label"
132
- COL_BATCH_SIZE = "max_num_seqs"
133
-
134
- def __init__(
135
- self,
136
- distributions: dict[str, dict[int, ITLMixtureModel]],
137
- approx_sampling_thresh: int = 30,
138
- ) -> None:
139
- self._distributions = {
140
- str(label): {int(b): m for b, m in per_batch.items()} for label, per_batch in distributions.items()
141
- }
142
- self._approx_sampling_thresh = int(approx_sampling_thresh)
143
-
144
- @property
145
- def distributions(self) -> dict[str, dict[int, ITLMixtureModel]]:
146
- """Nested dict: `model_label -> batch_size -> ITLMixtureModel`."""
147
- return self._distributions
148
-
149
- def sample_avg(
150
- self,
151
- model_label: str,
152
- batch_size: int,
153
- n_replicas: int,
154
- rng: np.random.Generator,
155
- ) -> float:
156
- """Sample a fleet-average ITL for the given model and batch size.
157
-
158
- Uses `ITLMixtureModel.sample_avg` under the hood, with the
159
- `approx_sampling_thresh` set at construction time.
160
-
161
- Args:
162
- model_label: Model label string.
163
- batch_size: Current batch size.
164
- n_replicas: Number of active replicas.
165
- rng: NumPy random generator for sampling.
166
-
167
- Returns:
168
- Fleet-average ITL in seconds.
169
-
170
- Raises:
171
- KeyError: If model or batch size is not in the store.
172
- """
173
- model_dists = self._distributions.get(model_label)
174
- if model_dists is None:
175
- raise KeyError(f"No ITL distributions for model={model_label!r}")
176
- params = model_dists.get(int(batch_size))
177
- if params is None:
178
- raise KeyError(
179
- f"No ITL distributions for model={model_label!r}, batch={batch_size}. "
180
- f"Available={sorted(model_dists.keys())}"
181
- )
182
- return params.sample_avg(
183
- n_replicas=n_replicas,
184
- rng=rng,
185
- exact_threshold=self._approx_sampling_thresh,
186
- )
187
-
188
- @classmethod
189
- def load(cls, csv_path: Path | str, approx_sampling_thresh: int = 30) -> ITLFitStore:
190
- """Load ITL mixture fits from a CSV.
191
-
192
- Expected columns: `model_label`, `max_num_seqs`, plus the
193
- `itl_mix_*` parameter columns produced by
194
- `ITLMixtureModel.to_dict()`.
195
-
196
- Args:
197
- csv_path: Path to the latency fits CSV.
198
- approx_sampling_thresh: Replica count above which sampling
199
- uses a CLT normal approximation instead of drawing
200
- individual samples.
201
- """
202
- csv_path = Path(csv_path)
203
- df = pd.read_csv(csv_path)
204
-
205
- required_cols = [cls.COL_MODEL_LABEL, cls.COL_BATCH_SIZE]
206
- missing = [c for c in required_cols if c not in df.columns]
207
- if missing:
208
- raise ValueError(f"{csv_path} missing columns: {missing}. Got: {list(df.columns)}")
209
-
210
- distributions: dict[str, dict[int, ITLMixtureModel]] = {}
211
- for row in df.to_dict(orient="records"):
212
- label = str(row[cls.COL_MODEL_LABEL]).strip()
213
- batch = int(row[cls.COL_BATCH_SIZE])
214
- distributions.setdefault(label, {})[batch] = ITLMixtureModel.from_dict(row)
215
-
216
- if not distributions:
217
- raise ValueError(f"No ITL mixture rows loaded from {csv_path}")
218
- return cls(distributions, approx_sampling_thresh=approx_sampling_thresh)
219
-
220
- def save(self, csv_path: Path) -> None:
221
- """Save ITL mixture fits to a CSV.
222
-
223
- Args:
224
- csv_path: Output CSV path.
225
- """
226
- csv_path = Path(csv_path)
227
- csv_path.parent.mkdir(parents=True, exist_ok=True)
228
- rows: list[dict[str, Any]] = []
229
- for label in sorted(self._distributions):
230
- for batch in sorted(self._distributions[label]):
231
- model = self._distributions[label][batch]
232
- rows.append(
233
- {
234
- self.COL_MODEL_LABEL: label,
235
- self.COL_BATCH_SIZE: batch,
236
- "itl_dist": "lognormal_mixture_2",
237
- **{f"itl_mix_{k}": v for k, v in model.to_dict().items()},
238
- }
239
- )
240
- pd.DataFrame(rows).to_csv(csv_path, index=False)
241
-
242
-
243
- class InferenceTemplateStore:
244
- """Pre-built per-GPU power templates for a specific simulation config.
245
-
246
- Created by [`InferenceTraceStore.build_templates`][..InferenceTraceStore.build_templates].
247
- Use [`template`][.template] to look up a template by model label and batch size.
248
- """
249
-
250
- def __init__(
251
- self,
252
- templates: dict[tuple[str, int], np.ndarray],
253
- batch_sizes_by_model: dict[str, list[int]],
254
- ) -> None:
255
- self._templates = templates
256
- self._batch_sizes_by_model = batch_sizes_by_model
257
-
258
- def template(self, model_label: str, batch_size: int) -> np.ndarray:
259
- """Return a pre-built per-GPU power template."""
260
- key = (str(model_label), int(batch_size))
261
- if key not in self._templates:
262
- raise KeyError(f"No template for model={model_label!r}, batch={batch_size}.")
263
- return self._templates[key]
264
-
265
- def batch_sizes(self, model_label: str) -> list[int]:
266
- """List of batch sizes available for a model."""
267
- sizes = self._batch_sizes_by_model.get(model_label)
268
- if sizes is None:
269
- raise KeyError(f"Unknown model: {model_label!r}")
270
- return list(sizes)
271
-
272
-
273
- class InferenceTraceStore:
274
- """Manages raw power traces loaded from CSV files.
275
-
276
- Indexed by `(model_label, batch_size)`. Provides:
277
-
278
- - [`load`][.load]: load traces discovered via a manifest CSV
279
- - [`build_templates`][.build_templates]: build per-GPU power
280
- templates for a specific simulation config, returning a
281
- [`InferenceTemplateStore`][..InferenceTemplateStore]
282
- """
283
-
284
- MANIFEST_COL_MODEL_LABEL = "model_label"
285
- MANIFEST_COL_NUM_GPUS = "num_gpus"
286
- MANIFEST_COL_BATCH_SIZE = "max_num_seqs"
287
- MANIFEST_COL_TRACE_FILE = "trace_file"
288
- TRACE_COL_TIME = "relative_time_s"
289
- TRACE_COL_POWER = "power_total_W"
290
-
291
- def __init__(self, traces: dict[str, dict[int, InferenceTrace]]) -> None:
292
- self._traces = {str(label): {int(b): tr for b, tr in per_batch.items()} for label, per_batch in traces.items()}
293
-
294
- @classmethod
295
- def load(cls, manifest: Path) -> InferenceTraceStore:
296
- """Load traces discovered via a manifest CSV.
297
-
298
- Trace file paths in the manifest are resolved relative to the
299
- manifest file's parent directory.
300
-
301
- Args:
302
- manifest: Path to the manifest CSV (e.g. `traces_summary.csv`).
303
- Expected columns: `model_label`, `num_gpus`, `max_num_seqs`,
304
- `trace_file`.
305
- """
306
- manifest = Path(manifest)
307
- base_dir = manifest.parent
308
- df = pd.read_csv(manifest)
309
-
310
- required_cols = [
311
- cls.MANIFEST_COL_MODEL_LABEL,
312
- cls.MANIFEST_COL_NUM_GPUS,
313
- cls.MANIFEST_COL_BATCH_SIZE,
314
- cls.MANIFEST_COL_TRACE_FILE,
315
- ]
316
- missing = [c for c in required_cols if c not in df.columns]
317
- if missing:
318
- raise ValueError(f"Manifest {manifest} missing columns: {missing}. Got: {list(df.columns)}")
319
-
320
- traces: dict[str, dict[int, InferenceTrace]] = {}
321
- for row in df.to_dict(orient="records"):
322
- label = str(row[cls.MANIFEST_COL_MODEL_LABEL])
323
- num_gpus = int(row[cls.MANIFEST_COL_NUM_GPUS])
324
- batch = int(row[cls.MANIFEST_COL_BATCH_SIZE])
325
- trace_path = base_dir / str(row[cls.MANIFEST_COL_TRACE_FILE])
326
-
327
- if not trace_path.exists():
328
- raise FileNotFoundError(f"Trace file not found: {trace_path} (model={label}, batch={batch})")
329
-
330
- tdf = pd.read_csv(trace_path)
331
- if cls.TRACE_COL_TIME not in tdf.columns or cls.TRACE_COL_POWER not in tdf.columns:
332
- raise ValueError(
333
- f"{trace_path} must contain {cls.TRACE_COL_TIME!r} and "
334
- f"{cls.TRACE_COL_POWER!r}. Got: {list(tdf.columns)}"
335
- )
336
-
337
- t = tdf[cls.TRACE_COL_TIME].to_numpy(float)
338
- p = tdf[cls.TRACE_COL_POWER].to_numpy(float)
339
- if np.any(np.diff(t) < 0):
340
- idx = np.argsort(t)
341
- t, p = t[idx], p[idx]
342
-
343
- traces.setdefault(label, {})[batch] = InferenceTrace(
344
- t_s=t,
345
- power_w=p,
346
- measured_gpus=num_gpus,
347
- )
348
-
349
- return cls(traces)
350
-
351
- def build_templates(
352
- self,
353
- *,
354
- duration_s: Fraction | float,
355
- dt_s: Fraction | float,
356
- steady_skip_s: float = 0.0,
357
- ) -> InferenceTemplateStore:
358
- """Build per-GPU power templates for all traces.
359
-
360
- Args:
361
- duration_s: Total simulation duration (seconds).
362
- dt_s: Simulation timestep (seconds).
363
- steady_skip_s: Skip this many seconds from the start of each
364
- trace to avoid warm-up transients.
365
-
366
- Returns:
367
- A [`InferenceTemplateStore`][openg2g.datacenter.workloads.inference.InferenceTemplateStore]
368
- holding the built templates.
369
- """
370
- templates: dict[tuple[str, int], np.ndarray] = {}
371
- batch_sizes_by_model: dict[str, list[int]] = {}
372
- for label, per_batch in self._traces.items():
373
- batch_sizes_by_model[label] = sorted(per_batch.keys())
374
- for batch, tr in per_batch.items():
375
- tpl = _build_per_gpu_power_template(
376
- tr,
377
- dt_s=dt_s,
378
- duration_s=duration_s,
379
- steady_skip_s=steady_skip_s,
380
- )
381
- templates[(label, batch)] = tpl
382
- return InferenceTemplateStore(templates, batch_sizes_by_model)
383
-
384
- def save(self, out_dir: Path) -> None:
385
- """Save traces and manifest CSV to a directory.
386
-
387
- Writes individual trace CSVs to `out_dir/traces/` and a manifest
388
- CSV at `out_dir/traces_summary.csv`.
389
-
390
- Args:
391
- out_dir: Output directory.
392
- """
393
- out_dir = Path(out_dir)
394
- traces_dir = out_dir / "traces"
395
- traces_dir.mkdir(parents=True, exist_ok=True)
396
-
397
- summary_rows: list[dict[str, Any]] = []
398
- for label in sorted(self._traces):
399
- for batch in sorted(self._traces[label]):
400
- tr = self._traces[label][batch]
401
- trace_name = f"{label}_num_gpus_{tr.measured_gpus}_max_num_seqs_{batch}.csv"
402
- pd.DataFrame(
403
- {
404
- self.TRACE_COL_TIME: tr.t_s,
405
- self.TRACE_COL_POWER: tr.power_w,
406
- }
407
- ).to_csv(traces_dir / trace_name, index=False)
408
- summary_rows.append(
409
- {
410
- self.MANIFEST_COL_MODEL_LABEL: label,
411
- self.MANIFEST_COL_NUM_GPUS: tr.measured_gpus,
412
- self.MANIFEST_COL_BATCH_SIZE: batch,
413
- self.MANIFEST_COL_TRACE_FILE: f"traces/{trace_name}",
414
- }
415
- )
416
- pd.DataFrame(summary_rows).to_csv(out_dir / "traces_summary.csv", index=False)
417
-
418
-
419
- class InferenceData:
420
- """LLM inference workload with offline simulation data.
421
-
422
- Bundles model specifications with power templates and latency
423
- distributions. Validates that all models have matching data entries.
424
-
425
- Args:
426
- models: Model specifications as a tuple of
427
- [`InferenceModelSpec`][openg2g.datacenter.config.InferenceModelSpec].
428
- power_templates: Pre-built per-GPU power templates for all models
429
- and batch sizes, created via
430
- [`InferenceTraceStore.build_templates`][..InferenceTraceStore.build_templates].
431
- itl_fits: Per-model ITL mixture distributions. Required when using
432
- controllers that read observed latency (e.g.,
433
- `OFOBatchSizeController`). When omitted, NaN is reported for
434
- observed latency.
435
- """
436
-
437
- def __init__(
438
- self,
439
- models: tuple[InferenceModelSpec, ...],
440
- *,
441
- power_templates: InferenceTemplateStore,
442
- itl_fits: ITLFitStore | None = None,
443
- ) -> None:
444
- if isinstance(power_templates, InferenceTraceStore):
445
- raise TypeError(
446
- "Expected a InferenceTemplateStore, got InferenceTraceStore. "
447
- "Call InferenceTraceStore.build_templates() first to create a InferenceTemplateStore."
448
- )
449
- if not models:
450
- raise ValueError("models must not be empty.")
451
- labels = [ms.model_label for ms in models]
452
- if len(labels) != len(set(labels)):
453
- raise ValueError(f"Duplicate model labels: {labels}")
454
-
455
- self._models = models
456
- self._power_templates: InferenceTemplateStore | None = power_templates
457
- self._trace_store: InferenceTraceStore | None = None
458
- self._itl_fit_store: ITLFitStore | None = None
459
- self._itl_fits = itl_fits
460
- self._itl_samples_df: pd.DataFrame | None = None
461
-
462
- for ms in self._models:
463
- try:
464
- power_templates.batch_sizes(ms.model_label)
465
- except KeyError:
466
- raise ValueError(
467
- f"Power templates missing for model {ms.model_label!r}. "
468
- f"Ensure InferenceTraceStore contains traces for all models."
469
- ) from None
470
-
471
- if itl_fits is not None and ms.model_label not in itl_fits.distributions:
472
- raise ValueError(
473
- f"ITL fits missing for model {ms.model_label!r}. "
474
- f"Available models in ITLFitStore: {sorted(itl_fits.distributions.keys())}"
475
- )
476
-
477
- @classmethod
478
- def generate(
479
- cls,
480
- models: tuple[InferenceModelSpec, ...],
481
- data_sources: dict[str, MLEnergySource],
482
- *,
483
- runs: Any = None,
484
- mlenergy_data_dir: Path | None = None,
485
- dt_s: float = 0.1,
486
- seed: int = 0,
487
- itl_sample_cap: int = 2048,
488
- ) -> InferenceData:
489
- """Generate inference data from ML.ENERGY benchmark data.
490
-
491
- Produces power traces and ITL mixture fits for all models and
492
- batch sizes specified in `data_sources`.
493
-
494
- Args:
495
- models: Model specifications.
496
- data_sources: Per-model benchmark data extraction settings,
497
- keyed by `model_label`.
498
- runs: Pre-loaded `LLMRuns` object. If `None`, loads from
499
- `mlenergy_data_dir` or the HuggingFace Hub.
500
- mlenergy_data_dir: Path to compiled mlenergy-data directory.
501
- Ignored if `runs` is provided.
502
- dt_s: Trace timestep (seconds).
503
- seed: Random seed for ITL fitting.
504
- itl_sample_cap: Maximum ITL samples per run for fitting.
505
-
506
- Returns:
507
- A new `InferenceData` with generated traces and ITL fits (no
508
- templates — call `InferenceTraceStore.build_templates()` on the
509
- saved/loaded store to get templates).
510
- """
511
- if runs is None:
512
- unique_tasks = {src.task for src in data_sources.values()}
513
- if mlenergy_data_dir:
514
- logger.info("Loading runs from %s (tasks: %s)", mlenergy_data_dir, sorted(unique_tasks))
515
- runs = LLMRuns.from_directory(str(mlenergy_data_dir), stable_only=False).task(*unique_tasks)
516
- else:
517
- logger.info("Loading runs from Hugging Face Hub (tasks: %s)", sorted(unique_tasks))
518
- runs = LLMRuns.from_hf(stable_only=False).task(*unique_tasks)
519
- if not runs:
520
- raise ValueError("No runs found for the specified tasks")
521
-
522
- subsets_by_label: dict[str, Any] = {}
523
- tl_frames: list[pd.DataFrame] = []
524
- itl_frames: list[pd.DataFrame] = []
525
-
526
- for ms in models:
527
- src = data_sources.get(ms.model_label)
528
- if src is None:
529
- raise ValueError(f"No data source for model {ms.model_label!r}")
530
- model_id = ms.model_id
531
- if not model_id:
532
- raise ValueError(f"model_id is required for data generation (model={ms.model_label!r})")
533
-
534
- subset = (
535
- runs.model_id(model_id).gpu_model(src.gpu).num_gpus(ms.gpus_per_replica).max_num_seqs(*src.batch_sizes)
536
- )
537
- if not subset:
538
- raise ValueError(
539
- f"Config matched zero runs: model_id={model_id!r}, "
540
- f"gpu={src.gpu!r}, num_gpus={ms.gpus_per_replica}, "
541
- f"batch_sizes={src.batch_sizes}"
542
- )
543
- subsets_by_label[ms.model_label] = subset
544
- logger.info(
545
- "%s: %d runs (model_id=%s, gpu=%s, num_gpus=%d, batches=%s)",
546
- ms.model_label,
547
- len(subset),
548
- model_id,
549
- src.gpu,
550
- ms.gpus_per_replica,
551
- sorted({r.max_num_seqs for r in subset}),
552
- )
553
-
554
- logger.info("Downloading raw result files for %d models ...", len(subsets_by_label))
555
- for subset in subsets_by_label.values():
556
- subset.download_raw_files(file="results")
557
- logger.info("Downloads complete. Extracting timelines and ITL samples ...")
558
-
559
- for label, subset in subsets_by_label.items():
560
- for run in subset:
561
- tl = run.timelines(metric="power.device_instant")
562
- tl["model_label"] = label
563
- tl["num_gpus"] = run.num_gpus
564
- tl["max_num_seqs"] = run.max_num_seqs
565
- tl["run_index"] = len(tl_frames)
566
- tl_frames.append(tl)
567
-
568
- itl = subset.inter_token_latencies()
569
- itl["model_label"] = label
570
- itl_frames.append(itl)
571
-
572
- all_tl = pd.concat(tl_frames, ignore_index=True)
573
- itl_samples_df = pd.concat(itl_frames, ignore_index=True)
574
- logger.info("Building trace store (%d timeline rows) and fitting ITL models ...", len(all_tl))
575
-
576
- trace_store = _build_trace_store_from_timelines(all_tl, dt_s=dt_s)
577
- itl_fit_store = _build_itl_fit_store(itl_samples_df, max_samples=itl_sample_cap, seed=seed)
578
-
579
- return cls._from_stores(
580
- models,
581
- trace_store=trace_store,
582
- itl_fit_store=itl_fit_store,
583
- itl_samples_df=itl_samples_df,
584
- )
585
-
586
- @classmethod
587
- def _from_stores(
588
- cls,
589
- models: tuple[InferenceModelSpec, ...],
590
- *,
591
- trace_store: InferenceTraceStore,
592
- itl_fit_store: ITLFitStore,
593
- itl_samples_df: pd.DataFrame | None = None,
594
- ) -> InferenceData:
595
- """Create from raw stores (internal, used by generate)."""
596
- instance = object.__new__(cls)
597
- instance._models = models
598
- instance._trace_store = trace_store
599
- instance._itl_fit_store = itl_fit_store
600
- instance._power_templates = None
601
- instance._itl_fits = itl_fit_store
602
- instance._itl_samples_df = itl_samples_df
603
- return instance
604
-
605
- def save(self, out_dir: Path, *, plot: bool = False) -> None:
606
- """Save traces and ITL fits to a directory.
607
-
608
- Args:
609
- out_dir: Output directory.
610
- plot: If `True`, also write characterization plots (power
611
- trajectories, ITL distributions).
612
- """
613
- out_dir = Path(out_dir)
614
- out_dir.mkdir(parents=True, exist_ok=True)
615
- if self._trace_store is not None:
616
- self._trace_store.save(out_dir)
617
- if self._itl_fits is not None:
618
- self._itl_fits.save(out_dir / "latency_fits.csv")
619
-
620
- (out_dir / "_manifest.json").write_text(
621
- json.dumps({"openg2g_version": openg2g.__version__}, indent=2, sort_keys=True)
622
- )
623
-
624
- if plot and self._trace_store is not None:
625
- _plot_power_trajectories(self._trace_store, self._models, out_dir)
626
- itl_samples = self._itl_samples_df
627
- if self._itl_fit_store is not None and itl_samples is not None:
628
- for ms in self._models:
629
- _plot_itl_distributions(self._itl_fit_store, itl_samples, ms.model_label, out_dir)
630
-
631
- @classmethod
632
- def load(
633
- cls,
634
- data_dir: Path,
635
- models: tuple[InferenceModelSpec, ...],
636
- *,
637
- duration_s: float = 600.0,
638
- dt_s: float = 0.1,
639
- steady_skip_s: float = 0.0,
640
- ) -> InferenceData:
641
- """Load from a generated data directory.
642
-
643
- Loads traces from `traces_summary.csv`, builds templates, and
644
- loads ITL fits from `latency_fits.csv`.
645
-
646
- Args:
647
- data_dir: Directory containing generated data.
648
- models: Model specifications.
649
- duration_s: Simulation duration for template building.
650
- dt_s: Simulation timestep for template building.
651
- steady_skip_s: Skip seconds for template building.
652
- """
653
- data_dir = Path(data_dir)
654
- _check_version_stamp(data_dir, "InferenceData")
655
- store = InferenceTraceStore.load(data_dir / "traces_summary.csv")
656
- templates = store.build_templates(duration_s=duration_s, dt_s=dt_s, steady_skip_s=steady_skip_s)
657
- itl_fits = ITLFitStore.load(data_dir / "latency_fits.csv")
658
- return cls(models, power_templates=templates, itl_fits=itl_fits)
659
-
660
- @classmethod
661
- def ensure(
662
- cls,
663
- data_dir: Path,
664
- models: tuple[InferenceModelSpec, ...],
665
- data_sources: dict[str, MLEnergySource] | None = None,
666
- *,
667
- mlenergy_data_dir: Path | None = None,
668
- plot: bool = False,
669
- duration_s: float = 600.0,
670
- dt_s: float = 0.1,
671
- steady_skip_s: float = 0.0,
672
- ) -> InferenceData:
673
- """Load from `data_dir`, generating first if needed.
674
-
675
- If `data_dir/traces_summary.csv` does not exist, generates
676
- inference data from ML.ENERGY benchmark data and saves it.
677
- Then loads and returns.
678
-
679
- Args:
680
- data_dir: Data directory (generated files go here).
681
- models: Model specifications.
682
- data_sources: Per-model benchmark data extraction settings,
683
- keyed by `model_label`. Required when no cached data exists.
684
- mlenergy_data_dir: Path to compiled mlenergy-data directory.
685
- plot: If `True`, generate characterization plots on generation.
686
- duration_s: Simulation duration for template building.
687
- dt_s: Simulation timestep for template building.
688
- steady_skip_s: Skip seconds for template building.
689
- """
690
- data_dir = Path(data_dir)
691
- if not (data_dir / "traces_summary.csv").exists():
692
- if data_sources is None:
693
- raise ValueError("data_sources required for InferenceData generation (no cached data)")
694
- logger.info("Generating inference data to %s ...", data_dir)
695
- cls.generate(
696
- models,
697
- data_sources,
698
- mlenergy_data_dir=mlenergy_data_dir,
699
- dt_s=dt_s,
700
- ).save(data_dir, plot=plot)
701
- return cls.load(data_dir, models, duration_s=duration_s, dt_s=dt_s, steady_skip_s=steady_skip_s)
702
-
703
- @property
704
- def models(self) -> tuple[InferenceModelSpec, ...]:
705
- """The model specifications."""
706
- return self._models
707
-
708
- @property
709
- def power_templates(self) -> InferenceTemplateStore:
710
- if self._power_templates is None:
711
- raise RuntimeError("power_templates not available (generate-only instance). Load from disk first.")
712
- return self._power_templates
713
-
714
- @property
715
- def itl_fits(self) -> ITLFitStore | None:
716
- return self._itl_fits
717
-
718
-
719
- def _check_version_stamp(data_dir: Path, label: str) -> None:
720
- """Log a warning if cached data was generated with a different openg2g version."""
721
- manifest_path = data_dir / "_manifest.json"
722
- if not manifest_path.exists():
723
- return
724
- try:
725
- manifest = json.loads(manifest_path.read_text())
726
- except (json.JSONDecodeError, OSError):
727
- return
728
- cached_version = manifest.get("openg2g_version", "unknown")
729
- if cached_version != openg2g.__version__:
730
- logger.warning(
731
- "%s: cached data generated with openg2g %s (current %s). Consider regenerating.",
732
- label,
733
- cached_version,
734
- openg2g.__version__,
735
- )
736
-
737
-
738
- def _build_trace_store_from_timelines(tl: pd.DataFrame, *, dt_s: float) -> InferenceTraceStore:
739
- """Build an InferenceTraceStore from raw timeline data.
740
-
741
- Args:
742
- tl: Combined timeline dataframe with columns `model_label`,
743
- `num_gpus`, `max_num_seqs`, `run_index`, `relative_time_s`, `value`.
744
- dt_s: Resampling timestep.
745
-
746
- Returns:
747
- An InferenceTraceStore with median-aggregated traces.
748
- """
749
- if tl.empty:
750
- raise ValueError("No timeline rows extracted from selected runs")
751
-
752
- traces: dict[str, dict[int, InferenceTrace]] = {}
753
- keys = [
754
- InferenceTraceStore.MANIFEST_COL_MODEL_LABEL,
755
- InferenceTraceStore.MANIFEST_COL_NUM_GPUS,
756
- InferenceTraceStore.MANIFEST_COL_BATCH_SIZE,
757
- ]
758
- for key, g in tl.groupby(keys, dropna=False):
759
- if not isinstance(key, tuple):
760
- raise TypeError(f"Expected tuple groupby key, got {type(key).__name__}")
761
- model_label, num_gpus, batch = str(key[0]), cast(int, key[1]), cast(int, key[2])
762
- series_list: list[tuple[np.ndarray, np.ndarray]] = []
763
- t_ends: list[float] = []
764
-
765
- for _run_index, rg in g.groupby("run_index"):
766
- rr = rg.sort_values("relative_time_s")
767
- t = rr["relative_time_s"].to_numpy(dtype=float)
768
- p = rr["value"].to_numpy(dtype=float)
769
- if t.size < 2:
770
- continue
771
- series_list.append((t, p))
772
- t_ends.append(float(t[-1]))
773
-
774
- if not series_list:
775
- continue
776
-
777
- t_end = float(np.median(np.asarray(t_ends, dtype=float)))
778
- grid = np.arange(0.0, t_end + 1e-12, float(dt_s), dtype=float)
779
- mats: list[np.ndarray] = []
780
- for t, p in series_list:
781
- mats.append(np.interp(grid, t, p, left=p[0], right=p[-1]))
782
- mat = np.vstack(mats)
783
- p_med = np.median(mat, axis=0)
784
-
785
- traces.setdefault(model_label, {})[batch] = InferenceTrace(
786
- t_s=grid,
787
- power_w=p_med,
788
- measured_gpus=int(num_gpus),
789
- )
790
-
791
- if not traces:
792
- raise ValueError("No trace profiles extracted from timeline data")
793
- return InferenceTraceStore(traces)
794
-
795
-
796
- def _build_itl_fit_store(
797
- itl: pd.DataFrame,
798
- *,
799
- max_samples: int,
800
- seed: int,
801
- ) -> ITLFitStore:
802
- """Build an ITLFitStore from raw ITL sample data.
803
-
804
- Args:
805
- itl: ITL sample dataframe with columns `model_label`, `num_gpus`,
806
- `max_num_seqs`, `itl_s`.
807
- max_samples: Maximum ITL samples per group for fitting.
808
- seed: Random seed for ITL fitting.
809
-
810
- Returns:
811
- An ITLFitStore with fitted mixture distributions.
812
- """
813
- if itl.empty:
814
- raise ValueError("No ITL samples provided")
815
-
816
- distributions: dict[str, dict[int, ITLMixtureModel]] = {}
817
- for key, g in itl.groupby(["model_label", "num_gpus", "max_num_seqs"], dropna=False):
818
- if not isinstance(key, tuple):
819
- raise TypeError(f"Expected tuple groupby key, got {type(key).__name__}")
820
- model_label, _num_gpus, batch = str(key[0]), cast(int, key[1]), cast(int, key[2])
821
- fit = ITLMixtureModel.fit(
822
- g["itl_s"].to_numpy(dtype=float),
823
- max_samples=max_samples,
824
- seed=seed,
825
- )
826
- distributions.setdefault(model_label, {})[batch] = fit
827
-
828
- if not distributions:
829
- raise ValueError("No ITL fits produced")
830
- return ITLFitStore(distributions)
831
-
832
-
833
- @dataclass(frozen=True)
834
- class InferenceAugmentedPower:
835
- """Result of inference power augmentation for one simulation timestep.
836
-
837
- Attributes:
838
- power_w: Three-phase inference power (watts), excluding base load.
839
- power_by_model_w: Per-model total active power (watts).
840
- active_replicas_by_model: Per-model active replica count.
841
- """
842
-
843
- power_w: ThreePhase
844
- power_by_model_w: dict[str, float] = field(default_factory=dict)
845
- active_replicas_by_model: dict[str, int] = field(default_factory=dict)
846
-
847
-
848
class InferencePowerAugmenter:
    """Convert per-GPU inference power into three-phase inference power.

    For each model, the per-server per-GPU values are multiplied by the
    layout's GPU counts and amplitude scales, optionally perturbed with
    Gaussian noise, clamped at zero, restricted to the servers that the
    activation policy reports as active, and summed per phase.

    The class does not depend on a particular datacenter backend: the
    caller supplies the per-GPU values (template-indexed offline, or
    live-measured online) and is responsible for adding facility base
    load on top of the returned inference power.

    Args:
        layouts: Server layout (physical topology) for each model label.
        policies: Activation policy for each model label; it decides
            which servers are active at a given timestep.
        seed: Seed for the noise random number generator.
    """

    def __init__(
        self,
        layouts: dict[str, ServerLayout],
        policies: dict[str, ActivationPolicy],
        seed: int = 0,
    ) -> None:
        self._layouts = layouts
        self._policies = policies
        self._seed = int(seed)
        self._rng = np.random.default_rng(self._seed)

    def augment(
        self,
        per_gpu_by_model: dict[str, np.ndarray],
        t: float,
    ) -> InferenceAugmentedPower:
        """Aggregate per-server per-GPU power into three-phase power.

        Args:
            per_gpu_by_model: Per-GPU power arrays of shape
                `(num_servers,)`, keyed by model label. Only models with
                active replicas should be present.
            t: Current simulation time (seconds).

        Returns:
            Three-phase inference totals together with per-model power
            and per-model active replica counts.
        """
        totals_by_phase = np.zeros(3, dtype=float)
        model_power_w: dict[str, float] = {}
        model_replicas: dict[str, int] = {}

        for label, per_gpu in per_gpu_by_model.items():
            layout = self._layouts[label]
            policy = self._policies[label]

            scaled = per_gpu * layout.gpus_per_server_list * layout.amplitude_scales
            if layout.noise_fraction > 0:
                # Noise is proportional to the server's power level, with
                # the level floored at 1.0.
                floor = np.maximum(scaled, 1.0)
                draws = self._rng.normal(0.0, 1.0, size=layout.num_servers)
                scaled += draws * layout.noise_fraction * floor
            scaled = np.maximum(scaled, 0.0)

            chosen = policy.active_indices(t)
            chosen_powers = scaled[chosen]
            chosen_phases = layout.phase_list[chosen]

            # np.add.at accumulates correctly when several servers share
            # the same phase index.
            per_phase = np.zeros(3, dtype=float)
            np.add.at(per_phase, chosen_phases, chosen_powers)
            totals_by_phase += per_phase

            model_power_w[label] = float(np.sum(chosen_powers))
            gpu_count = int(np.sum(layout.gpus_per_server_list[chosen]))
            model_replicas[label] = gpu_count // layout.gpus_per_replica

        three_phase = ThreePhase(
            a=float(totals_by_phase[0]),
            b=float(totals_by_phase[1]),
            c=float(totals_by_phase[2]),
        )
        return InferenceAugmentedPower(
            power_w=three_phase,
            power_by_model_w=model_power_w,
            active_replicas_by_model=model_replicas,
        )

    def reset(self) -> None:
        """Restore the noise RNG to the state it had after construction."""
        self._rng = np.random.default_rng(self._seed)
934
-
935
-
936
def _lognorm_pdf(x: np.ndarray, sigma: float, scale: float) -> np.ndarray:
    """Evaluate the lognormal PDF f(x; sigma, scale).

    Entries with x <= 0 are left at zero; only strictly positive entries
    are evaluated.
    """
    x = np.asarray(x, dtype=float)
    positive = x > 0
    xs = x[positive]

    log_ratio = np.log(xs / scale)
    normalizer = 1.0 / (xs * sigma * np.sqrt(2.0 * np.pi))

    density = np.zeros_like(x)
    density[positive] = normalizer * np.exp(-(log_ratio**2) / (2.0 * sigma * sigma))
    return density
944
-
945
-
946
def _plot_power_trajectories(
    trace_store: InferenceTraceStore,
    models: tuple[InferenceModelSpec, ...],
    out_dir: Path,
    *,
    rolling_window: int = 10,
) -> None:
    """Draw total GPU power over time, one curve per batch size.

    Each model gets its own subplot row. Curves are smoothed with a
    moving average of `rolling_window` samples when the trace is long
    enough. The figure is written to `out_dir / "power_trajectories.png"`.
    """
    import matplotlib.pyplot as plt

    labels = [spec.model_label for spec in models]
    fig, axes = plt.subplots(len(labels), 1, figsize=(10, 5), dpi=160, squeeze=False)

    panel_letters = "abcdefghij"

    for row, model_label in enumerate(labels):
        ax = axes[row, 0]
        traces_by_batch = trace_store._traces.get(model_label, {})
        if not traces_by_batch:
            ax.set_title(f"{model_label} (no traces found)")
            continue

        batch_sizes = sorted(traces_by_batch.keys())
        palette = plt.get_cmap("tab10")

        for color_idx, batch in enumerate(batch_sizes):
            trace = traces_by_batch[batch]
            time_s = trace.t_s.copy()
            power_kw = trace.power_w.copy() / 1000.0

            if rolling_window > 1 and len(power_kw) >= rolling_window:
                box = np.ones(rolling_window) / rolling_window
                averaged = np.convolve(power_kw, box, mode="same")
                # Put the raw samples back at both ends, where the
                # "same"-mode window runs past the data.
                edge = rolling_window // 2
                averaged[:edge] = power_kw[:edge]
                averaged[-edge:] = power_kw[-edge:]
                power_kw = averaged

            ax.plot(time_s, power_kw, label=f"batch={batch}", color=palette(color_idx))

        if row < len(panel_letters):
            label_char = panel_letters[row]
        else:
            label_char = ""
        # The GPU count shown in the title is taken from the smallest batch's trace.
        num_gpus = traces_by_batch[batch_sizes[0]].measured_gpus
        gpu_suffix = "GPUs" if num_gpus > 1 else "GPU"
        ax.set_title(
            f"({label_char}) {model_label}: Total-GPU Power ({num_gpus} {gpu_suffix})",
            fontsize=13,
        )
        ax.set_ylabel("Power (kW)", fontsize=11)
        if row == 0:
            ax.legend(fontsize=9, ncol=len(batch_sizes), loc="lower center", frameon=True, framealpha=0.9)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(left=0)

    axes[-1, 0].set_xlabel("Time (seconds)", fontsize=11)
    fig.tight_layout()

    save_path = out_dir / "power_trajectories.png"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    logger.info("Saved power trajectories plot to %s", save_path)
1012
-
1013
-
1014
def _plot_itl_distributions(
    itl_fit_store: ITLFitStore,
    itl_samples_df: pd.DataFrame,
    model_label: str,
    out_dir: Path,
    *,
    hist_bins: int = 120,
    hist_alpha: float = 0.12,
    x_lo_q: float = 0.5,
    x_hi_q: float = 99.5,
    grid_n: int = 1200,
) -> None:
    """Overlay the fitted ITL mixture PDFs of one model across batch sizes.

    The main axes show, for every batch size, a histogram of the samples
    and the fitted mixture PDF. An inset breaks the largest batch size's
    mixture into its steady and stall components. The figure is written
    to `out_dir / "itl_distributions_{model_label}.png"`. Nothing is
    drawn when the model has no fits or no samples.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    def _mixture_params(fit) -> tuple[float, float, float, float, float, float]:
        # Order: loc, pi_steady, sigma_steady, scale_steady, sigma_stall, scale_stall.
        raw = fit.to_dict()
        return (
            float(raw["loc"]),
            float(raw["pi_steady"]),
            float(raw["sigma_steady"]),
            float(raw["scale_steady"]),
            float(raw["sigma_stall"]),
            float(raw["scale_stall"]),
        )

    fits_by_batch = itl_fit_store.distributions.get(model_label, {})
    if not fits_by_batch:
        logger.warning("No ITL distributions for model %s, skipping plot", model_label)
        return

    model_samples = itl_samples_df[itl_samples_df["model_label"] == model_label]
    batches = sorted(fits_by_batch.keys())

    pooled = model_samples[model_samples["max_num_seqs"].isin(batches)]["itl_s"].to_numpy(dtype=float)
    if len(pooled) == 0:
        logger.warning("No ITL samples for model %s, skipping plot", model_label)
        return

    # The x-range is the requested percentile band of all samples pooled
    # over the fitted batch sizes.
    lo = float(np.percentile(pooled, x_lo_q))
    hi = float(np.percentile(pooled, x_hi_q))
    grid = np.linspace(lo, hi, grid_n)

    fig, ax = plt.subplots(figsize=(7.2, 3.2), dpi=160)

    cmap = plt.get_cmap("tab10" if len(batches) <= 10 else "tab20")
    colors = {batch: cmap(idx % cmap.N) for idx, batch in enumerate(batches)}

    for batch in batches:
        loc, pi, s1, sc1, s2, sc2 = _mixture_params(fits_by_batch[batch])

        shifted = grid - loc
        pdf_mix = pi * _lognorm_pdf(shifted, s1, sc1) + (1 - pi) * _lognorm_pdf(shifted, s2, sc2)

        color = colors[batch]
        batch_samples = model_samples[model_samples["max_num_seqs"] == batch]["itl_s"].to_numpy(dtype=float)
        if len(batch_samples) > 0:
            ax.hist(batch_samples, bins=hist_bins, range=(lo, hi), density=True, alpha=hist_alpha, color=color)
        ax.plot(grid, pdf_mix, linewidth=2.2, color=color, label=f"batch={batch}")

    ax.set_title(f"(a) {model_label}: ITL distribution vs batch size")
    ax.set_xlabel("Inter-token latency (seconds)")
    ax.set_ylabel("Density")
    ax.legend(ncol=4, fontsize=9, frameon=True)
    ax.set_xlim(lo, hi)

    # Inset: component decomposition for the largest batch size.
    inset_batch = max(batches)
    loc, pi, s1, sc1, s2, sc2 = _mixture_params(fits_by_batch[inset_batch])

    inset_samples = model_samples[model_samples["max_num_seqs"] == inset_batch]["itl_s"].to_numpy(dtype=float)
    if len(inset_samples) > 0:
        lo_i = float(np.percentile(inset_samples, 0.5))
        hi_i = float(np.percentile(inset_samples, 99.5))
    else:
        # Without samples for this batch, reuse the main axes' range.
        lo_i, hi_i = lo, hi
    grid_i = np.linspace(lo_i, hi_i, 600)

    shifted_i = grid_i - loc
    pdf_steady = pi * _lognorm_pdf(shifted_i, s1, sc1)
    pdf_stall = (1 - pi) * _lognorm_pdf(shifted_i, s2, sc2)
    pdf_mix_i = pdf_steady + pdf_stall

    axins = inset_axes(
        ax,
        width="38%",
        height="55%",
        loc="lower right",
        bbox_to_anchor=(-0.1, 0.1, 1, 1),
        bbox_transform=ax.transAxes,
    )

    inset_color = colors[inset_batch]
    if len(inset_samples) > 0:
        axins.hist(inset_samples, bins=60, range=(lo_i, hi_i), density=True, alpha=0.20, color=inset_color)

    axins.plot(grid_i, pdf_mix_i, lw=2.0, color=inset_color, label="mixture")
    axins.plot(grid_i, pdf_steady, lw=1.6, ls="--", color="0.25", label="steady")
    axins.plot(grid_i, pdf_stall, lw=1.6, ls=":", color="0.25", label="stall")

    axins.set_title(f"(b) batch={inset_batch} components", fontsize=9)
    axins.set_xlim(lo_i, hi_i)
    axins.tick_params(axis="both", labelsize=8)
    axins.grid(True, alpha=0.25)
    axins.legend(fontsize=8, frameon=True, loc="upper right")

    fig.tight_layout()

    save_path = out_dir / f"itl_distributions_{model_label}.png"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    logger.info("Saved ITL distributions plot for %s to %s", model_label, save_path)
1131
-
1132
-
1133
class RequestsConfig(BaseModel):
    """Settings used when building the per-model JSONL request files.

    Attributes:
        dataset: Name of the prompt dataset, `"gpqa"` or `"lm-arena-chat"`.
        num_requests: How many requests to sample for each model.
        max_completion_tokens: Upper bound on output tokens per request.
        seed: Seed for dataset shuffling and for oversampling.
        system_prompt: System prompt placed in front of every request.
    """

    # Instances are immutable once created.
    model_config = ConfigDict(frozen=True)

    dataset: str = "lm-arena-chat"
    num_requests: int = 1000
    max_completion_tokens: int = 512
    seed: int = 0
    system_prompt: str = "You are a helpful AI assistant."
1151
-
1152
-
1153
class RequestStore:
    """Per-model request dicts for online load generation.

    Each model's requests are stored as a list of OpenAI Chat Completion
    streaming request dicts, suitable for submission to a vLLM server.

    Attributes:
        requests_by_model: Mapping from model label to request dicts.
    """

    def __init__(self, requests_by_model: dict[str, list[dict]]) -> None:
        self.requests_by_model = requests_by_model

    @classmethod
    def generate(
        cls,
        models: Sequence[InferenceModelSpec],
        config: RequestsConfig | None = None,
        *,
        extra_body_by_model: dict[str, dict] | None = None,
    ) -> RequestStore:
        """Sample prompts and build per-model request dicts.

        Requires `pip install datasets openai`.

        Args:
            models: Model specifications. Uses `model_id` for the API
                model field.
            config: Request generation config. Uses defaults if `None`.
            extra_body_by_model: Optional per-model extra fields merged
                into every request dict (e.g. `chat_template_kwargs`).
                Keyed by `model_label`.

        Returns:
            A new store holding one request list per model label.

        Raises:
            ValueError: If `config.dataset` is not a known dataset name.
        """
        import random as _random

        from datasets import load_dataset
        from openai.types.chat import (
            ChatCompletionAssistantMessageParam,
            ChatCompletionContentPartTextParam,
            ChatCompletionMessageParam,
            ChatCompletionSystemMessageParam,
            ChatCompletionUserMessageParam,
        )
        from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming

        if config is None:
            config = RequestsConfig()

        def _text_part(text: str) -> ChatCompletionContentPartTextParam:
            return ChatCompletionContentPartTextParam(type="text", text=text)

        def _prompt_to_messages(prompt: str | list[str]) -> list[ChatCompletionMessageParam]:
            # A plain string is a single user turn; a list alternates
            # user / assistant / user / ... starting with the user.
            if isinstance(prompt, str):
                return [ChatCompletionUserMessageParam(role="user", content=[_text_part(prompt)])]
            msgs: list[ChatCompletionMessageParam] = [
                ChatCompletionUserMessageParam(role="user", content=[_text_part(prompt[0])])
            ]
            for i, turn in enumerate(prompt[1:]):
                if i % 2 == 0:
                    msgs.append(ChatCompletionAssistantMessageParam(role="assistant", content=[_text_part(turn)]))
                else:
                    msgs.append(ChatCompletionUserMessageParam(role="user", content=[_text_part(turn)]))
            return msgs

        def _maybe_oversample(items: list, target: int, seed: int) -> None:
            # Pad `items` in place up to `target` entries by re-drawing
            # from the entries it started with.
            if len(items) >= target:
                return
            rng = _random.Random(seed)
            original = list(items)
            while len(items) < target:
                items.append(rng.choice(original))

        def _sample_lm_arena_chat(num_requests: int, seed: int) -> list[str | list[str]]:
            data = load_dataset("lmarena-ai/arena-human-preference-100k", split="train").shuffle(seed=seed)
            prompts: list[str | list[str]] = []
            for item in data:
                num_turns = item["turn"]
                conversation = item["conversation_a"]
                # Every conversation prefix that ends on a user message
                # becomes one prompt.
                for turns in range(num_turns):
                    if len(prompts) >= num_requests:
                        break
                    messages: list[str] = []
                    for message in conversation[: 2 * turns + 1]:
                        messages.append(message["content"])
                    prompts.append(messages if len(messages) > 1 else messages[0])
                if len(prompts) >= num_requests:
                    break
            _maybe_oversample(prompts, num_requests, seed)
            return prompts

        def _sample_gpqa(num_requests: int, seed: int) -> list[str | list[str]]:
            data = load_dataset("Idavidrein/gpqa", "gpqa_extended", split="train", streaming=True).shuffle(seed=seed)
            # A private generator yields the same shuffles as seeding the
            # module-level one, but leaves process-wide RNG state alone.
            rng = _random.Random(seed)
            prompts: list[str | list[str]] = []
            for item in data:
                if len(prompts) >= num_requests:
                    break
                choices = [
                    item["Incorrect Answer 1"].strip(),
                    item["Incorrect Answer 2"].strip(),
                    item["Incorrect Answer 3"].strip(),
                    item["Correct Answer"].strip(),
                ]
                rng.shuffle(choices)
                question = item["Question"]
                prompt = f"What is the correct answer to the following question: {question}\n\nChoices:"
                for letter, choice in zip("ABCD", choices, strict=True):
                    prompt += f"\n({letter}) {choice}"
                prompts.append(prompt)
            _maybe_oversample(prompts, num_requests, seed)
            return prompts

        samplers = {"lm-arena-chat": _sample_lm_arena_chat, "gpqa": _sample_gpqa}
        sampler = samplers.get(config.dataset)
        if sampler is None:
            raise ValueError(f"Unknown dataset: {config.dataset!r}. Available: {sorted(samplers)}")

        extra = extra_body_by_model or {}

        requests_by_model: dict[str, list[dict]] = {}
        for spec in models:
            label = spec.model_label
            model_id = spec.model_id

            logger.info("Sampling %d %s prompts for %s (%s)...", config.num_requests, config.dataset, label, model_id)
            prompts = sampler(num_requests=config.num_requests, seed=config.seed)

            system_msgs: list[ChatCompletionMessageParam] = []
            if config.system_prompt:
                system_msgs.append(ChatCompletionSystemMessageParam(role="system", content=config.system_prompt))

            template = CompletionCreateParamsStreaming(
                model=model_id,
                messages=system_msgs,
                max_completion_tokens=config.max_completion_tokens,
                stream=True,
                stream_options={"include_usage": True, "continuous_usage_stats": True},
            )
            if label in extra:
                template.update(extra[label])

            reqs: list[dict] = []
            for prompt in prompts:
                # Shallow-copy the template, then give each request its
                # own message list (system messages + sampled prompt).
                request = dict(template)
                request["messages"] = list(template["messages"]) + _prompt_to_messages(prompt)
                reqs.append(request)
            requests_by_model[label] = reqs

        return cls(requests_by_model)

    def save(self, out_dir: Path) -> None:
        """Write per-model JSONL files to `out_dir`.

        One file named `{model_label}.jsonl` is written per model, with
        one JSON-encoded request per line.

        Args:
            out_dir: Output directory. Created if it doesn't exist.
        """
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        for label, reqs in self.requests_by_model.items():
            out_path = out_dir / f"{label}.jsonl"
            with open(out_path, "w", encoding="utf-8") as f:
                for req in reqs:
                    f.write(json.dumps(req) + "\n")
            logger.info("Wrote %d requests for %s to %s", len(reqs), label, out_path)

    @classmethod
    def load(cls, out_dir: Path) -> RequestStore:
        """Load per-model JSONL files from `out_dir`.

        The model label of each file is its name without the `.jsonl`
        suffix. Blank lines are ignored.

        Args:
            out_dir: Directory containing `{model_label}.jsonl` files.
        """
        requests_by_model: dict[str, list[dict]] = {}
        out_dir = Path(out_dir)
        for path in sorted(out_dir.glob("*.jsonl")):
            label = path.stem
            reqs: list[dict] = []
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        reqs.append(json.loads(line))
            requests_by_model[label] = reqs
            logger.info("Loaded %d requests for %s", len(reqs), label)
        return cls(requests_by_model)

    @classmethod
    def ensure(
        cls,
        out_dir: Path,
        models: Sequence[InferenceModelSpec] | None = None,
        config: RequestsConfig | None = None,
        *,
        extra_body_by_model: dict[str, dict] | None = None,
    ) -> RequestStore:
        """Load request files from `out_dir`, generating first if needed.

        Generation happens when `out_dir` is missing or holds no
        `*.jsonl` file; an existing but empty directory therefore no
        longer produces a silently empty store.

        Args:
            out_dir: Directory for JSONL files.
            models: Required if request files don't exist yet.
            config: Request generation config. Uses defaults if `None`.
            extra_body_by_model: Optional per-model extra fields for
                request generation. Keyed by `model_label`.

        Raises:
            ValueError: If no request files exist and `models` is `None`.
        """
        out_dir = Path(out_dir)
        has_cached_requests = out_dir.is_dir() and any(out_dir.glob("*.jsonl"))
        if not has_cached_requests:
            if models is None:
                raise ValueError("models required (no cached request data)")
            logger.info("Generating request files to %s ...", out_dir)
            cls.generate(models, config, extra_body_by_model=extra_body_by_model).save(out_dir)
        return cls.load(out_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/datacenter/workloads/training.py DELETED
@@ -1,200 +0,0 @@
1
- """Training workload: typed trace data and periodic overlay evaluation."""
2
-
3
- from __future__ import annotations
4
-
5
- import logging
6
- from dataclasses import dataclass
7
- from pathlib import Path
8
-
9
- import numpy as np
10
- import pandas as pd
11
- from pydantic import BaseModel, ConfigDict
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
class TrainingTraceParams(BaseModel):
    """Knobs for generating a synthetic training-like power trace.

    Attributes:
        duration_s: Length of the trace (seconds).
        dt_s: Sampling interval (seconds).
        seed: Seed for the random number generator.
        P_hi: Power of the high plateau (W).
        P_lo: Power of the low plateau (W).
        sigma_hi: Standard deviation of the noise on high plateaus (W).
        sigma_lo: Standard deviation of the noise on low plateaus (W).
        seg_lo_range: Range of low-segment durations (seconds).
        seg_hi_range: Range of high-segment durations (seconds).
        dip_prob_per_sec: Expected number of brief dips per second.
        dip_depth_range: Range of dip depths (W below the current level).
        dip_dur_range: Range of dip durations (seconds).
        smooth_window_s: Width of the smoothing window (seconds).
        ramp_s: Duration of the initial warm-up ramp (seconds).
        ramp_from: Power at the start of the ramp (W).
    """

    # Instances are immutable once created.
    model_config = ConfigDict(frozen=True)

    duration_s: float = 1000.0
    dt_s: float = 0.1
    seed: int = 2
    P_hi: float = 225.0
    P_lo: float = 175.0
    sigma_hi: float = 50.0
    sigma_lo: float = 50.0
    seg_lo_range: tuple[float, float] = (10.0, 15.0)
    seg_hi_range: tuple[float, float] = (35.0, 40.0)
    dip_prob_per_sec: float = 0.010
    dip_depth_range: tuple[float, float] = (120.0, 125.0)
    dip_dur_range: tuple[float, float] = (0.06, 0.14)
    smooth_window_s: float = 0.30
    ramp_s: float = 18.0
    ramp_from: float = 50.0
54
-
55
-
56
def _generate_training_like_trace(params: TrainingTraceParams) -> tuple[np.ndarray, np.ndarray]:
    """Generate a synthetic training-like per-GPU power trace.

    The trace is built in stages: a square-wave envelope alternating
    between `P_hi` and `P_lo`, additive Gaussian noise, moving-average
    smoothing, randomly placed brief dips, and an initial warm-up ramp
    that caps the first `ramp_s` seconds.

    Args:
        params: Generation parameters.

    Returns:
        Tuple of (time_array, power_array).
    """
    rng = np.random.default_rng(params.seed)
    t = np.arange(0.0, params.duration_s, params.dt_s)
    n = t.size

    # Piecewise-constant envelope, starting on the high plateau.
    env = np.empty(n, dtype=float)
    i = 0
    state_hi = True

    while i < n:
        if state_hi:
            dur = rng.uniform(*params.seg_hi_range)
            level = params.P_hi
        else:
            dur = rng.uniform(*params.seg_lo_range)
            level = params.P_lo

        # Advance by at least one sample: a segment that rounds to zero
        # steps would otherwise leave `i` unchanged and loop forever.
        steps = max(1, int(np.round(dur / params.dt_s)))
        j = min(n, i + steps)
        env[i:j] = level
        i = j
        state_hi = not state_hi

    # Noise level depends on which plateau each sample sits on.
    noise = np.zeros(n, dtype=float)
    hi_mask = env > (params.P_hi + params.P_lo) / 2
    noise[hi_mask] = rng.normal(0.0, params.sigma_hi, size=hi_mask.sum())
    noise[~hi_mask] = rng.normal(0.0, params.sigma_lo, size=(~hi_mask).sum())

    p = env + noise

    # Moving-average smoothing; skipped when the window is one sample.
    w = max(1, int(np.round(params.smooth_window_s / params.dt_s)))
    if w > 1:
        kernel = np.ones(w) / w
        p = np.convolve(p, kernel, mode="same")

    # Brief dips: the count is Poisson-distributed, and each dip lowers a
    # short window of samples without letting power go below zero.
    n_dips = rng.poisson(params.dip_prob_per_sec * params.duration_s)
    for _ in range(n_dips):
        t0 = rng.uniform(0.0, params.duration_s)
        k0 = int(t0 / params.dt_s)
        dur = rng.uniform(*params.dip_dur_range)
        k1 = min(n, k0 + int(np.round(dur / params.dt_s)))
        if k1 <= k0:
            continue
        depth = rng.uniform(*params.dip_depth_range)
        p[k0:k1] = np.maximum(p[k0:k1] - depth, 0.0)

    # Warm-up: cap the start of the trace by a linear ramp up to P_hi.
    if params.ramp_s > 0:
        k_ramp = min(n, int(np.round(params.ramp_s / params.dt_s)))
        ramp = np.linspace(params.ramp_from, params.P_hi, k_ramp)
        p[:k_ramp] = np.minimum(p[:k_ramp], ramp)

    return t, p
115
-
116
-
117
@dataclass(frozen=True)
class TrainingTrace:
    """Power trace of a single GPU running a training workload.

    Attributes:
        t_s: Time vector (seconds), monotonically increasing.
        power_w: Power vector (watts) for one GPU, same length as `t_s`.
    """

    # CSV column names used by `save` and `load`.
    COL_TIME = "t_s"
    COL_POWER = "power_W"

    t_s: np.ndarray
    power_w: np.ndarray

    def __post_init__(self) -> None:
        n_time = len(self.t_s)
        n_power = len(self.power_w)
        if n_time != n_power:
            raise ValueError(f"t_s and power_w must have the same length, got {len(self.t_s)} and {len(self.power_w)}")
        if n_time < 2:
            raise ValueError("Training trace must have >= 2 samples.")

    @classmethod
    def generate(cls, params: TrainingTraceParams | None = None) -> TrainingTrace:
        """Build a synthetic training-like power trace.

        Args:
            params: Generation parameters. Uses defaults if `None`.

        Returns:
            A new [`TrainingTrace`][.] with generated data.
        """
        effective = TrainingTraceParams() if params is None else params
        times, powers = _generate_training_like_trace(effective)
        return cls(t_s=times, power_w=powers)

    def save(self, csv_path: Path) -> None:
        """Write the trace to a CSV file, creating parent directories.

        Args:
            csv_path: Output CSV path.
        """
        csv_path = Path(csv_path)
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        columns = {self.COL_TIME: self.t_s, self.COL_POWER: self.power_w}
        pd.DataFrame(columns).to_csv(csv_path, index=False)

    @classmethod
    def load(cls, csv_path: Path) -> TrainingTrace:
        """Read a training trace from CSV.

        Negative power values are clipped to zero, and rows are sorted by
        time when the time column is not already non-decreasing.

        Args:
            csv_path: Path to CSV with columns `t_s` and `power_W`.

        Raises:
            ValueError: If either required column is missing.
        """
        csv_path = Path(csv_path)
        df = pd.read_csv(csv_path)
        required = (cls.COL_TIME, cls.COL_POWER)
        if any(column not in df.columns for column in required):
            raise ValueError(
                f"{csv_path} must have columns {cls.COL_TIME!r} and {cls.COL_POWER!r}. Got {list(df.columns)}"
            )

        times = df[cls.COL_TIME].to_numpy(float)
        powers = np.clip(df[cls.COL_POWER].to_numpy(float), 0.0, None)

        out_of_order = np.any(np.diff(times) < 0)
        if out_of_order:
            order = np.argsort(times)
            times = times[order]
            powers = powers[order]

        return cls(t_s=times, power_w=powers)

    @classmethod
    def ensure(cls, csv_path: Path, params: TrainingTraceParams | None = None) -> TrainingTrace:
        """Return the trace stored at `csv_path`, generating it first if absent.

        Args:
            csv_path: Path to the training trace CSV.
            params: Generation parameters, used only when no cached file
                exists. Defaults are used if `None`.
        """
        csv_path = Path(csv_path)
        if not csv_path.exists():
            logger.info("Generating training trace to %s ...", csv_path)
            cls.generate(params).save(csv_path)
        return cls.load(csv_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/events.py DELETED
@@ -1,60 +0,0 @@
1
- """Clock-aligned simulation event primitives."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass, field
6
- from typing import TYPE_CHECKING, Any, Literal
7
-
8
- if TYPE_CHECKING:
9
- from openg2g.clock import SimulationClock
10
- from openg2g.coordinator import SimulationLog
11
-
12
# Component families that are allowed to emit events.
EventSource = Literal["coordinator", "controller", "datacenter", "grid", "custom"]


@dataclass(frozen=True)
class SimEvent:
    """Immutable simulation event stamped with clock metadata.

    Attributes:
        tick: Integer tick at which the event was emitted.
        t_s: Simulation time of the event, in seconds.
        source: Component family that emitted the event.
        topic: Dot-separated topic string.
        data: Arbitrary key-value payload; empty by default.
    """

    tick: int
    t_s: float
    source: EventSource
    topic: str
    # Each event gets its own fresh dict when no payload is supplied.
    data: dict[str, Any] = field(default_factory=dict)
32
-
33
-
34
@dataclass
class EventEmitter:
    """Helper bound to one source that builds [`SimEvent`][..SimEvent]
    instances from the current clock reading and hands them to a log.

    Attributes:
        clock: Simulation clock read for each event's tick and time.
        log: `SimulationLog` that receives emitted events.
        source: Component family label attached to all events.
    """

    clock: SimulationClock
    log: SimulationLog
    source: EventSource

    def emit(self, topic: str, data: dict[str, Any] | None = None) -> None:
        """Emit a single event carrying the clock's current tick and time."""
        now_s = float(self.clock.time_s)
        tick = int(self.clock.step)
        # Copy the payload so the event holds its own dict.
        payload = dict(data) if data is not None else {}
        event = SimEvent(
            tick=tick,
            t_s=now_s,
            source=self.source,
            topic=str(topic),
            data=payload,
        )
        self.log.emit(event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/grid/__init__.py DELETED
File without changes
openg2g/grid/base.py DELETED
@@ -1,203 +0,0 @@
1
- """Abstract base class for grid backends and grid-level types."""
2
-
3
- from __future__ import annotations
4
-
5
- from abc import ABC, abstractmethod
6
- from dataclasses import dataclass
7
- from fractions import Fraction
8
- from typing import Generic, TypeVar, final
9
-
10
- import numpy as np
11
-
12
- from openg2g.clock import SimulationClock
13
- from openg2g.common import ThreePhase
14
- from openg2g.events import EventEmitter
15
- from openg2g.grid.command import GridCommand
16
- from openg2g.grid.config import TapPosition
17
-
18
-
19
- @dataclass(frozen=True)
20
- class PhaseVoltages:
21
- """Per-phase voltage magnitudes in per-unit.
22
-
23
- Phases missing from the bus have NaN for that field.
24
-
25
- Attributes:
26
- a: Phase A voltage magnitude (pu).
27
- b: Phase B voltage magnitude (pu).
28
- c: Phase C voltage magnitude (pu).
29
- """
30
-
31
- a: float
32
- b: float
33
- c: float
34
-
35
-
36
- @dataclass(frozen=True)
37
- class BusVoltages:
38
- """Per-bus, per-phase voltage map.
39
-
40
- Access: voltages["671"].a -> Vpu for bus 671, phase A.
41
- Buses missing a phase have NaN for that field.
42
- """
43
-
44
- _data: dict[str, PhaseVoltages]
45
-
46
- def __getitem__(self, bus: str) -> PhaseVoltages:
47
- return self._data[bus]
48
-
49
- def buses(self) -> list[str]:
50
- """Return the list of bus names."""
51
- return list(self._data.keys())
52
-
53
- def __contains__(self, bus: str) -> bool:
54
- return bus in self._data
55
-
56
-
57
- @dataclass(frozen=True)
58
- class GridState:
59
- """State emitted by the grid simulator each timestep.
60
-
61
- Attributes:
62
- time_s: Simulation time in seconds.
63
- voltages: Per-bus, per-phase voltage magnitudes.
64
- tap_positions: Current regulator tap positions, or `None` if
65
- no regulator is present.
66
- """
67
-
68
- time_s: float
69
- voltages: BusVoltages
70
- tap_positions: TapPosition | None = None
71
-
72
-
73
- GridStateT = TypeVar("GridStateT", bound=GridState)
74
-
75
-
76
- class GridBackend(Generic[GridStateT], ABC):
77
- """Interface for grid simulation backends."""
78
-
79
- _INIT_SENTINEL = object()
80
-
81
- def __init__(self) -> None:
82
- self._state: GridStateT | None = None
83
- self._history: list[GridStateT] = []
84
- self._grid_base_init = GridBackend._INIT_SENTINEL
85
-
86
- def _check_base_init(self) -> None:
87
- if getattr(self, "_grid_base_init", None) is not GridBackend._INIT_SENTINEL:
88
- raise TypeError(f"{type(self).__name__}.__init__ must call super().__init__().")
89
-
90
- @property
91
- @abstractmethod
92
- def dt_s(self) -> Fraction:
93
- """Native timestep as a Fraction (seconds)."""
94
-
95
- @final
96
- @property
97
- def state(self) -> GridStateT:
98
- """Latest emitted state.
99
-
100
- Raises:
101
- RuntimeError: If accessed before the first `step()` call.
102
- """
103
- self._check_base_init()
104
- if self._state is None:
105
- raise RuntimeError(f"{type(self).__name__}.state accessed before first step().")
106
- return self._state
107
-
108
- @final
109
- def history(self, n: int | None = None) -> list[GridStateT]:
110
- """Return emitted state history (all, or latest `n`)."""
111
- self._check_base_init()
112
- if n is None:
113
- return list(self._history)
114
- if n <= 0:
115
- return []
116
- return list(self._history[-int(n) :])
117
-
118
- @final
119
- def do_step(
120
- self,
121
- clock: SimulationClock,
122
- power_samples_w: list[ThreePhase],
123
- events: EventEmitter,
124
- ) -> GridStateT:
125
- """Call `step`, record the state, and return it.
126
-
127
- Called by the coordinator. Subclasses should not override this.
128
- """
129
- self._check_base_init()
130
- state = self.step(clock, power_samples_w, events)
131
- self._state = state
132
- self._history.append(state)
133
- return state
134
-
135
- @abstractmethod
136
- def step(
137
- self,
138
- clock: SimulationClock,
139
- power_samples_w: list[ThreePhase],
140
- events: EventEmitter,
141
- ) -> GridStateT:
142
- """Advance one native timestep and return state for this step."""
143
-
144
- @abstractmethod
145
- def apply_control(self, command: GridCommand, events: EventEmitter) -> None:
146
- """Apply one control command."""
147
-
148
- @abstractmethod
149
- def voltages_vector(self) -> np.ndarray:
150
- """Return voltage magnitudes in `v_index` order."""
151
-
152
- @abstractmethod
153
- def estimate_sensitivity(self, perturbation_kw: float = 100.0) -> tuple[np.ndarray, np.ndarray]:
154
- """Estimate voltage sensitivity matrix (H = dv/dp) and return `(H, v0)`."""
155
-
156
- @property
157
- @abstractmethod
158
- def v_index(self) -> list[tuple[str, int]]:
159
- """Fixed (bus, phase) ordering used by [`voltages_vector`][..voltages_vector]."""
160
-
161
- @final
162
- def do_reset(self) -> None:
163
- """Clear history and call `reset`.
164
-
165
- Called by the coordinator. Subclasses should not override this.
166
- """
167
- self._check_base_init()
168
- self._state = None
169
- self._history.clear()
170
- self.reset()
171
-
172
- @abstractmethod
173
- def reset(self) -> None:
174
- """Reset simulation state to initial conditions.
175
-
176
- Called by the coordinator (via `do_reset`) before each
177
- [`start`][..start]. Must clear all simulation state: counters,
178
- cached values. Configuration (dt_s, case files, tap schedules)
179
- is not affected. History is cleared automatically by
180
- `do_reset`.
181
-
182
- Abstract so every implementation explicitly enumerates its state.
183
- A forgotten field is a bug -- not clearing it silently corrupts
184
- the second run.
185
- """
186
-
187
- def start(self) -> None:
188
- """Acquire per-run resources (solver circuits, connections).
189
-
190
- Called after [`reset`][..reset], before the simulation loop.
191
- Override for backends that need resource acquisition (e.g.,
192
- [`OpenDSSGrid`][openg2g.grid.opendss.OpenDSSGrid] compiles its
193
- DSS circuit here). No-op by default because most offline
194
- components have no resources to acquire.
195
- """
196
-
197
- def stop(self) -> None:
198
- """Release per-run resources. Simulation state is preserved.
199
-
200
- Called after the simulation loop in LIFO order. Override for
201
- backends that acquired resources in [`start`][..start]. No-op
202
- by default.
203
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/grid/command.py DELETED
@@ -1,31 +0,0 @@
1
- """Command types targeting grid backends."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass
6
-
7
- from openg2g.grid.config import TapPosition
8
-
9
-
10
- class GridCommand:
11
- """Base for commands targeting the grid backend.
12
-
13
- Subclass this for each concrete grid command kind.
14
- The coordinator routes commands to backends based on this type hierarchy.
15
- """
16
-
17
- def __init__(self) -> None:
18
- if type(self) is GridCommand:
19
- raise TypeError("GridCommand cannot be instantiated directly; subclass it.")
20
-
21
-
22
- @dataclass(frozen=True)
23
- class SetTaps(GridCommand):
24
- """Set regulator tap positions.
25
-
26
- Attributes:
27
- tap_position: Per-phase tap ratios. Phases set to `None` are
28
- unchanged.
29
- """
30
-
31
- tap_position: TapPosition
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/grid/config.py DELETED
@@ -1,92 +0,0 @@
1
- """Grid configuration and schedule types."""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Iterator
6
- from dataclasses import dataclass
7
-
8
-
9
- @dataclass(frozen=True)
10
- class TapPosition:
11
- """Regulator tap position per phase, as per-unit tap ratios.
12
-
13
- Each field is the tap ratio for the corresponding phase regulator.
14
- Phases set to `None` are left unchanged when applied. At least
15
- one phase must be specified.
16
-
17
- Combine with [`at`][.at] and `|` to build a [`TapSchedule`][..TapSchedule]:
18
-
19
- ```python
20
- TAP_STEP = 0.00625 # standard 5/8% tap step
21
- schedule = (
22
- TapPosition(a=1.0 + 14 * TAP_STEP, b=1.0 + 6 * TAP_STEP, c=1.0 + 15 * TAP_STEP).at(t=0)
23
- | TapPosition(a=1.1).at(t=1500)
24
- | TapPosition(a=1.0625, c=1.0625).at(t=3300)
25
- )
26
- ```
27
- """
28
-
29
- a: float | None = None
30
- b: float | None = None
31
- c: float | None = None
32
-
33
- def __post_init__(self) -> None:
34
- if self.a is None and self.b is None and self.c is None:
35
- raise ValueError("TapPosition requires at least one phase (a, b, or c).")
36
-
37
- def at(self, t: float) -> TapSchedule:
38
- """Schedule this position at time `t` seconds."""
39
- return TapSchedule(((t, self),))
40
-
41
-
42
- class TapSchedule:
43
- """Ordered sequence of scheduled tap positions.
44
-
45
- Build using [`TapPosition.at`][..TapPosition.at] and the `|` operator:
46
-
47
- ```python
48
- TAP_STEP = 0.00625 # standard 5/8% tap step
49
- schedule = (
50
- TapPosition(a=1.0 + 14 * TAP_STEP, b=1.0 + 6 * TAP_STEP, c=1.0 + 15 * TAP_STEP).at(t=0)
51
- | TapPosition(a=1.0 + 16 * TAP_STEP).at(t=25 * 60)
52
- )
53
- ```
54
-
55
- Raises:
56
- ValueError: If two entries share the same timestamp.
57
- """
58
-
59
- __slots__ = ("_entries",)
60
-
61
- def __init__(self, entries: tuple[tuple[float, TapPosition], ...]) -> None:
62
- self._entries = tuple(sorted(entries, key=lambda e: e[0]))
63
- times = [t for t, _ in self._entries]
64
- if len(times) != len(set(times)):
65
- seen: set[float] = set()
66
- dupes = sorted({t for t in times if t in seen or seen.add(t)})
67
- raise ValueError(f"TapSchedule has duplicate timestamps: {dupes}")
68
-
69
- def __or__(self, other: TapSchedule) -> TapSchedule:
70
- return TapSchedule(self._entries + other._entries)
71
-
72
- def __iter__(self) -> Iterator[tuple[float, TapPosition]]:
73
- return iter(self._entries)
74
-
75
- def __len__(self) -> int:
76
- return len(self._entries)
77
-
78
- def __bool__(self) -> bool:
79
- return bool(self._entries)
80
-
81
- def __repr__(self) -> str:
82
- parts: list[str] = []
83
- for t, p in self._entries:
84
- fields = []
85
- if p.a is not None:
86
- fields.append(f"a={p.a}")
87
- if p.b is not None:
88
- fields.append(f"b={p.b}")
89
- if p.c is not None:
90
- fields.append(f"c={p.c}")
91
- parts.append(f"TapPosition({', '.join(fields)}).at(t={t})")
92
- return " | ".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/grid/opendss.py DELETED
@@ -1,476 +0,0 @@
1
- """OpenDSS-based grid simulator."""
2
-
3
- from __future__ import annotations
4
-
5
- import functools
6
- import logging
7
- import math
8
- from fractions import Fraction
9
- from pathlib import Path
10
- from typing import TYPE_CHECKING, Literal
11
-
12
- import numpy as np
13
-
14
- from openg2g.clock import SimulationClock
15
- from openg2g.common import ThreePhase
16
- from openg2g.events import EventEmitter
17
- from openg2g.grid.base import BusVoltages, GridBackend, GridState, PhaseVoltages
18
- from openg2g.grid.command import GridCommand, SetTaps
19
- from openg2g.grid.config import TapPosition
20
-
21
- if TYPE_CHECKING:
22
- from opendssdirect import dss
23
- else:
24
- try:
25
- from opendssdirect.OpenDSSDirect import OpenDSSDirect
26
-
27
- dss = OpenDSSDirect(prefer_lists=False)
28
- except ImportError:
29
- dss = None
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
- _PHASES = (1, 2, 3)
34
- _PHASE_NAME = {1: "A", 2: "B", 3: "C"}
35
- _PHASE_TO_ATTR = {1: "a", 2: "b", 3: "c"}
36
- _DC_LOAD_NAMES = ("DataCenterA", "DataCenterB", "DataCenterC")
37
-
38
-
39
- class OpenDSSGrid(GridBackend[GridState]):
40
- """OpenDSS-based grid simulator for distribution-level voltage analysis.
41
-
42
- !!! Info
43
- `OpenDSSDirect.py` is required to use this component.
44
- Install with: `pip install openg2g[opendss]`.
45
-
46
- This component uses OpenDSS purely as a power flow solver. The user's DSS
47
- case file defines the network topology and any built-in controls (voltage
48
- regulators, capacitor banks, etc.). The `dss_controls` flag determines
49
- whether OpenDSS iterates those controls during each solve:
50
-
51
- - `dss_controls=False` (default): Uses `SolveNoControl()`. OpenDSS runs
52
- a single power flow without iterating any built-in control loops.
53
- RegControls are disabled after initial tap setting. All voltage
54
- regulation is managed externally through
55
- [`apply_control`][.apply_control] commands (e.g., from
56
- [`TapScheduleController`][openg2g.controller.tap_schedule.TapScheduleController]
57
- or
58
- [`OFOBatchSizeController`][openg2g.controller.ofo.OFOBatchSizeController]).
59
-
60
- - `dss_controls=True`: Uses `Solve()`. OpenDSS iterates its built-in
61
- control loops (RegControls, CapControls, etc.) as defined in the case
62
- file. Use this when you want DSS-native control automation.
63
-
64
- Args:
65
- dss_case_dir: Absolute path to the directory containing OpenDSS case
66
- files (e.g. line codes, bus coordinates).
67
- dss_master_file: Name of the master DSS file, relative to
68
- `dss_case_dir` (e.g. `"IEEE13Nodeckt.dss"`). OpenDSS resolves
69
- all `redirect` and `BusCoords` paths in the master file
70
- relative to this directory.
71
- dc_bus: Bus name where the datacenter is connected.
72
- dc_bus_kv: Line-to-line voltage (kV) at the datacenter bus.
73
- power_factor: Power factor of the datacenter loads.
74
- dt_s: Grid simulation timestep (seconds).
75
- connection_type: Connection type for DC loads (default `"wye"`).
76
- dss_controls: Whether to let OpenDSS iterate its built-in control
77
- loops during each solve. Default False.
78
- initial_tap_position: Initial regulator tap position applied before
79
- the first solve. Each field is a per-unit tap ratio.
80
- exclude_buses: Buses to exclude from voltage indexing (e.g., source bus).
81
- """
82
-
83
- def __init__(
84
- self,
85
- *,
86
- dss_case_dir: str | Path,
87
- dss_master_file: str,
88
- dc_bus: str,
89
- dc_bus_kv: float,
90
- power_factor: float,
91
- dt_s: Fraction = Fraction(1),
92
- connection_type: Literal["wye", "delta"] = "wye",
93
- dss_controls: bool = False,
94
- initial_tap_position: TapPosition | None = None,
95
- exclude_buses: tuple[str, ...] = ("rg60",),
96
- ) -> None:
97
- super().__init__()
98
- if dss is None:
99
- raise RuntimeError("OpenDSSDirect is required. Install with: pip install openg2g[opendss]")
100
-
101
- self._case_dir = str(Path(dss_case_dir).resolve())
102
- self._master = str(dss_master_file)
103
- self._dc_bus = str(dc_bus)
104
- self._dc_bus_kv = float(dc_bus_kv)
105
- self._power_factor = float(power_factor)
106
- pf = max(min(self._power_factor, 0.999999), 1e-6)
107
- self._tanphi = math.tan(math.acos(pf))
108
- self._dt_s = dt_s
109
- self._connection_type: Literal["wye", "delta"] = connection_type
110
- self._dss_controls = bool(dss_controls)
111
-
112
- self._initial_tap_position = initial_tap_position
113
- self._reg_map: dict[str, tuple[str, int, int]] | None = None
114
- self._phase_to_reg: dict[int, str] | None = None
115
- self._exclude_buses = tuple(str(b) for b in exclude_buses)
116
-
117
- # Simulation state (cleared by reset)
118
- self._prev_power: ThreePhase | None = None
119
-
120
- # DSS-derived data (populated by start)
121
- self._started = False
122
- self.all_buses: list[str] = []
123
- self.buses_with_phase: dict[int, list[str]] = {}
124
- self._v_index: list[tuple[str, int]] = []
125
-
126
- @property
127
- def dt_s(self) -> Fraction:
128
- return self._dt_s
129
-
130
- @property
131
- def v_index(self) -> list[tuple[str, int]]:
132
- if not self._started:
133
- raise RuntimeError("OpenDSSGrid.v_index accessed before start().")
134
- return list(self._v_index)
135
-
136
- def step(
137
- self,
138
- clock: SimulationClock,
139
- power_samples_w: list[ThreePhase],
140
- events: EventEmitter,
141
- ) -> GridState:
142
- """Advance one grid period and return the resulting grid state.
143
-
144
- Uses the most recent power sample from the accumulated buffer to
145
- run a single power flow solve. If no samples are provided (grid
146
- runs faster than datacenter), the last known power is reused.
147
-
148
- Args:
149
- clock: Current simulation clock.
150
- power_samples_w: List of
151
- [`ThreePhase`][openg2g.common.ThreePhase] power samples
152
- (Watts) accumulated since the last grid step.
153
-
154
- Returns:
155
- [`GridState`][openg2g.grid.base.GridState] with voltages
156
- from the solve.
157
- """
158
- if not power_samples_w:
159
- if self._prev_power is None:
160
- raise RuntimeError("OpenDSSGrid.step() called with no power samples and no previous power.")
161
- power = self._prev_power
162
- else:
163
- power = power_samples_w[-1]
164
-
165
- self._prev_power = power
166
-
167
- kW_A = power.a / 1e3
168
- kW_B = power.b / 1e3
169
- kW_C = power.c / 1e3
170
-
171
- for name, kw in zip(_DC_LOAD_NAMES, (kW_A, kW_B, kW_C), strict=True):
172
- dss.Loads.Name(name)
173
- dss.Loads.kW(kw)
174
- dss.Loads.kvar(kw * self._tanphi)
175
-
176
- self._solve()
177
-
178
- voltages = self._snapshot_bus_voltages()
179
- return GridState(time_s=clock.time_s, voltages=voltages, tap_positions=self._read_current_taps())
180
-
181
- @functools.singledispatchmethod
182
- def apply_control(self, command: GridCommand, events: EventEmitter) -> None:
183
- """Apply a control command. Dispatches on command type."""
184
- raise TypeError(f"OpenDSSGrid does not support {type(command).__name__}")
185
-
186
- @apply_control.register
187
- def apply_control_set_taps(self, command: SetTaps, events: EventEmitter) -> None:
188
- tap_map = self._tap_position_to_reg_dict(command.tap_position)
189
- self._set_reg_taps(tap_map)
190
- events.emit(
191
- "grid.taps.updated",
192
- {"tap_position": command.tap_position},
193
- )
194
-
195
- def reset(self) -> None:
196
- self._prev_power = None
197
- self._started = False
198
-
199
- def start(self) -> None:
200
- self._init_dss()
201
- self._v_index = self._build_v_index()
202
- self._build_vmag_indices()
203
- self._build_snapshot_indices()
204
- self._started = True
205
- logger.info(
206
- "OpenDSSGrid: case=%s, dc_bus=%s, dt=%s s, dss_controls=%s, %d buses, %d bus-phase pairs",
207
- self._master,
208
- self._dc_bus,
209
- self._dt_s,
210
- self._dss_controls,
211
- len(self.all_buses),
212
- len(self._v_index),
213
- )
214
-
215
- def voltages_vector(self) -> np.ndarray:
216
- """Return voltage magnitudes (pu) in the fixed
217
- [`v_index`][openg2g.grid.base.GridBackend.v_index] ordering."""
218
- if not self._started:
219
- raise RuntimeError("OpenDSSGrid.voltages_vector() called before start().")
220
- vmag = dss.Circuit.AllBusMagPu()
221
- return vmag[self._v_index_to_vmag]
222
-
223
- def estimate_sensitivity(
224
- self,
225
- perturbation_kw: float = 100.0,
226
- ) -> tuple[np.ndarray, np.ndarray]:
227
- """Estimate voltage sensitivity matrix H = dv/dp (pu per kW).
228
-
229
- Uses finite differences on the 3 single-phase DC loads.
230
-
231
- Returns:
232
- Tuple of `(sensitivity, baseline_voltages)`.
233
- `sensitivity` has shape `(M, 3)` where M is the number
234
- of bus-phase pairs in
235
- [`v_index`][openg2g.grid.base.GridBackend.v_index].
236
- `baseline_voltages` has shape `(M,)`.
237
- """
238
- perturbation_kw = float(perturbation_kw)
239
- if perturbation_kw <= 0:
240
- raise ValueError("perturbation_kw must be positive.")
241
-
242
- dq_kvar = perturbation_kw * self._tanphi
243
-
244
- # Always use SolveNoControl so that DSS-native controls
245
- # (RegControls, CapControls) don't move between the baseline
246
- # and perturbed solves. We need the open-loop plant sensitivity
247
- # dv/dp, not the closed-loop response.
248
- dss.Solution.SolveNoControl()
249
- baseline_voltages = self.voltages_vector()
250
-
251
- # Baseline P, Q for each DC load
252
- p0 = np.zeros(3, dtype=float)
253
- q0 = np.zeros(3, dtype=float)
254
- for j, ld in enumerate(_DC_LOAD_NAMES):
255
- dss.Loads.Name(ld)
256
- p0[j] = float(dss.Loads.kW())
257
- q0[j] = float(dss.Loads.kvar())
258
-
259
- M = len(self._v_index)
260
- sensitivity = np.zeros((M, 3), dtype=float)
261
-
262
- for j, ld in enumerate(_DC_LOAD_NAMES):
263
- dss.Text.Command(f"Edit Load.{ld} kW={p0[j] + perturbation_kw:.6f} kvar={q0[j] + dq_kvar:.6f}")
264
- dss.Solution.SolveNoControl()
265
-
266
- sensitivity[:, j] = (self.voltages_vector() - baseline_voltages) / perturbation_kw
267
-
268
- # Restore load to baseline before next perturbation
269
- dss.Text.Command(f"Edit Load.{ld} kW={p0[j]:.6f} kvar={q0[j]:.6f}")
270
-
271
- # Re-solve with all loads restored (use normal solve to leave
272
- # DSS in its expected state for subsequent step() calls)
273
- self._solve()
274
-
275
- return sensitivity, baseline_voltages
276
-
277
- def _init_dss(self) -> None:
278
- dss.Basic.ClearAll()
279
- master_path = str(Path(self._case_dir) / self._master)
280
- dss.Text.Command(f'Compile "{master_path}"')
281
-
282
- self._reg_map = self._cache_regcontrol_map()
283
- self._phase_to_reg = self._build_phase_to_reg_map(self._reg_map)
284
-
285
- # Add 3 single-phase DC loads
286
- if self._connection_type == "wye":
287
- load_kv = self._dc_bus_kv / math.sqrt(3.0)
288
- elif self._connection_type == "delta":
289
- load_kv = self._dc_bus_kv
290
- else:
291
- raise ValueError(f"Unsupported connection_type: {self._connection_type!r}")
292
- for ph, nm in zip(_PHASES, _DC_LOAD_NAMES, strict=True):
293
- dss.Text.Command(
294
- f"New Load.{nm} bus1={self._dc_bus}.{ph} phases=1 "
295
- f"conn={self._connection_type} kV={load_kv:.6f} kW=0 kvar=0 model=1"
296
- )
297
-
298
- dss.Text.Command("Reset")
299
- dss.Text.Command("Set Mode=Time")
300
- dss.Text.Command(f"Set Stepsize={float(self._dt_s)}s")
301
- if self._dss_controls:
302
- dss.Text.Command("Set ControlMode=Time")
303
- else:
304
- dss.Text.Command("Set ControlMode=Off")
305
-
306
- if self._initial_tap_position is not None:
307
- self._set_reg_taps(self._tap_position_to_reg_dict(self._initial_tap_position))
308
-
309
- self._solve()
310
- self._cache_node_map()
311
- self._cache_buses_with_phases()
312
-
313
- def _solve(self) -> None:
314
- """Run the OpenDSS power flow solver."""
315
- if self._dss_controls:
316
- dss.Solution.Solve()
317
- else:
318
- dss.Solution.SolveNoControl()
319
-
320
- def _cache_buses_with_phases(self) -> None:
321
- """Populate `all_buses` and `buses_with_phase` from the compiled circuit."""
322
- self.all_buses = list(dss.Circuit.AllBusNames())
323
- self.buses_with_phase = {ph: [] for ph in _PHASES}
324
- for bus, phase in self._node_map:
325
- if phase in _PHASES:
326
- self.buses_with_phase[phase].append(bus)
327
-
328
- def _cache_node_map(self) -> None:
329
- """Cache the mapping from AllBusMagPu indices to (bus, phase) pairs."""
330
- self._node_map: list[tuple[str, int]] = []
331
- for name in dss.Circuit.AllNodeNames():
332
- parts = name.split(".")
333
- bus = parts[0]
334
- phase = int(parts[1]) if len(parts) > 1 else 0
335
- self._node_map.append((bus, phase))
336
-
337
- def _build_vmag_indices(self) -> None:
338
- """Pre-compute index arrays for fast voltage vector extraction."""
339
- node_idx = {(bus, ph): i for i, (bus, ph) in enumerate(self._node_map)}
340
- self._v_index_to_vmag = np.array(
341
- [node_idx[(bus, ph)] for bus, ph in self._v_index],
342
- dtype=int,
343
- )
344
-
345
- def _build_snapshot_indices(self) -> None:
346
- """Pre-compute index arrays for `_snapshot_bus_voltages`.
347
-
348
- Builds a `(num_buses, 3)` array where entry `[b, p]` is the
349
- index into `AllBusMagPu()` for bus `b`, phase `p+1`, or -1 if
350
- that bus-phase pair doesn't exist (mapped to NaN at read time).
351
- """
352
- bus_to_idx = {bus: i for i, bus in enumerate(self.all_buses)}
353
- n_buses = len(self.all_buses)
354
- # -1 means "missing phase -> NaN"
355
- self._snap_indices = np.full((n_buses, 3), -1, dtype=int)
356
- for vmag_idx, (bus, phase) in enumerate(self._node_map):
357
- if 1 <= phase <= 3:
358
- bus_idx = bus_to_idx.get(bus)
359
- if bus_idx is not None:
360
- self._snap_indices[bus_idx, phase - 1] = vmag_idx
361
-
362
- def _snapshot_bus_voltages(self) -> BusVoltages:
363
- """Snapshot all per-bus, per-phase voltage magnitudes into BusVoltages.
364
-
365
- Uses pre-computed index arrays and a single `AllBusMagPu()` bulk
366
- read. Missing bus-phase pairs (index == -1) are set to NaN.
367
- """
368
- vmag = dss.Circuit.AllBusMagPu()
369
- # Append a NaN sentinel so index -1 reads as NaN
370
- vmag_ext = np.append(vmag, float("nan"))
371
- volts = vmag_ext[self._snap_indices]
372
- data = {
373
- bus: PhaseVoltages(a=float(volts[i, 0]), b=float(volts[i, 1]), c=float(volts[i, 2]))
374
- for i, bus in enumerate(self.all_buses)
375
- }
376
- return BusVoltages(_data=data)
377
-
378
- def _build_v_index(self) -> list[tuple[str, int]]:
379
- excl = {b.lower() for b in self._exclude_buses}
380
- v_index: list[tuple[str, int]] = []
381
- for ph in _PHASES:
382
- for b in self.buses_with_phase.get(ph, []):
383
- if str(b).lower() in excl:
384
- continue
385
- v_index.append((str(b), int(ph)))
386
- return v_index
387
-
388
- @staticmethod
389
- def _cache_regcontrol_map() -> dict[str, tuple[str, int, int]]:
390
- """Enumerate RegControls and discover their transformer, winding, and phase.
391
-
392
- Returns:
393
- Mapping of `rc_name -> (transformer_name, winding, phase)` where
394
- phase is 1/2/3 for A/B/C. Phase is determined from the
395
- transformer's bus connections (e.g., `"650.1"` -> phase 1).
396
- """
397
- reg_map: dict[str, tuple[str, int, int]] = {}
398
- for rc in dss.RegControls:
399
- rc_name = rc.Name().lower()
400
- xf = rc.Transformer()
401
- w = int(rc.Winding())
402
-
403
- # Discover phase from transformer bus connections
404
- dss.Transformers.Name(xf)
405
- bus_names = list(dss.CktElement.BusNames())
406
- phase = 0
407
- for bus_str in bus_names:
408
- parts = str(bus_str).split(".")
409
- if len(parts) >= 2:
410
- phase = int(parts[1])
411
- break
412
- if phase not in (1, 2, 3):
413
- raise RuntimeError(
414
- f"Cannot determine phase for RegControl '{rc_name}' "
415
- f"(transformer={xf}, buses={bus_names}). "
416
- f"Expected bus format 'name.phase' with phase in {{1,2,3}}."
417
- )
418
-
419
- reg_map[rc_name] = (xf, w, phase)
420
- return reg_map
421
-
422
- @staticmethod
423
- def _build_phase_to_reg_map(reg_map: dict[str, tuple[str, int, int]]) -> dict[int, str]:
424
- """Build reverse mapping from phase (1/2/3) to RegControl name."""
425
- phase_to_reg: dict[int, str] = {}
426
- for rc_name, (_xf, _wdg, phase) in reg_map.items():
427
- if phase in phase_to_reg:
428
- logger.warning(
429
- "Multiple RegControls on phase %s: '%s' and '%s'. Using '%s'.",
430
- _PHASE_NAME[phase],
431
- phase_to_reg[phase],
432
- rc_name,
433
- rc_name,
434
- )
435
- phase_to_reg[phase] = rc_name
436
- return phase_to_reg
437
-
438
- def _tap_position_to_reg_dict(self, pos: TapPosition) -> dict[str, float]:
439
- """Map phase tap ratios to OpenDSS RegControl names using discovered mapping."""
440
- if self._phase_to_reg is None:
441
- raise RuntimeError("_phase_to_reg not initialized; call start() first")
442
- d: dict[str, float] = {}
443
- for phase, attr in _PHASE_TO_ATTR.items():
444
- val = getattr(pos, attr)
445
- if val is not None and phase in self._phase_to_reg:
446
- d[self._phase_to_reg[phase]] = val
447
- return d
448
-
449
- def _set_reg_taps(self, tap_map: dict[str, float]) -> None:
450
- """Write tap ratios to OpenDSS RegControl transformers."""
451
- if self._reg_map is None:
452
- self._reg_map = self._cache_regcontrol_map()
453
-
454
- tap_map_lc = {str(k).lower(): float(v) for k, v in tap_map.items()}
455
-
456
- for rc_key, (xfmr, wdg, _phase) in self._reg_map.items():
457
- if rc_key in tap_map_lc:
458
- tap_pu = tap_map_lc[rc_key]
459
- dss.Text.Command(f"Edit Transformer.{xfmr} Wdg={wdg} Tap={tap_pu:.6f}")
460
-
461
- def _read_current_taps(self) -> TapPosition:
462
- """Read current regulator tap positions from OpenDSS."""
463
- if self._reg_map is None:
464
- self._reg_map = self._cache_regcontrol_map()
465
- if self._phase_to_reg is None:
466
- self._phase_to_reg = self._build_phase_to_reg_map(self._reg_map)
467
-
468
- phase_taps: dict[str, float | None] = {"a": None, "b": None, "c": None}
469
- for _rc_key, (xfmr, wdg, phase) in self._reg_map.items():
470
- dss.Transformers.Name(xfmr)
471
- dss.Transformers.Wdg(wdg)
472
- attr = _PHASE_TO_ATTR.get(phase)
473
- if attr is not None:
474
- phase_taps[attr] = float(dss.Transformers.Tap())
475
-
476
- return TapPosition(a=phase_taps["a"], b=phase_taps["b"], c=phase_taps["c"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/metrics/__init__.py DELETED
File without changes
openg2g/metrics/voltage.py DELETED
@@ -1,94 +0,0 @@
1
- """Voltage violation metrics for all-bus, all-phase analysis."""
2
-
3
- from __future__ import annotations
4
-
5
- from dataclasses import dataclass
6
-
7
- import numpy as np
8
-
9
- from openg2g.grid.base import GridState
10
-
11
-
12
- @dataclass
13
- class VoltageStats:
14
- """Summary voltage statistics over a simulation run.
15
-
16
- Attributes:
17
- worst_vmin: Lowest voltage observed across all buses and phases (pu).
18
- worst_vmax: Highest voltage observed across all buses and phases (pu).
19
- violation_time_s: Total time with at least one bus-phase violating
20
- voltage bounds (seconds).
21
- integral_violation_pu_s: Integrated voltage violation magnitude
22
- across all bus-phase pairs (pu * s).
23
- """
24
-
25
- worst_vmin: float
26
- worst_vmax: float
27
- violation_time_s: float
28
- integral_violation_pu_s: float
29
-
30
-
31
- def compute_allbus_voltage_stats(
32
- grid_states: list[GridState],
33
- *,
34
- v_min: float = 0.95,
35
- v_max: float = 1.05,
36
- exclude_buses: tuple[str, ...] = ("rg60",),
37
- ) -> VoltageStats:
38
- """Compute voltage violation statistics across all buses and phases.
39
-
40
- For each snapshot the integral violation sums
41
- `max(v_min - v, 0) + max(v - v_max, 0)` over every non-excluded
42
- bus-phase pair, then integrates over time. A snapshot counts as
43
- "violated" when this sum is positive.
44
-
45
- Args:
46
- grid_states: Sequence of [`GridState`][openg2g.grid.base.GridState]
47
- objects from a simulation run.
48
- v_min: Lower voltage bound (pu).
49
- v_max: Upper voltage bound (pu).
50
- exclude_buses: Bus names to exclude from statistics (case-insensitive).
51
- """
52
- if len(grid_states) < 2:
53
- raise ValueError(
54
- f"At least two grid states are required to compute voltage statistics (got {len(grid_states)})."
55
- )
56
-
57
- times = np.array([gs.time_s for gs in grid_states], dtype=float)
58
- dt = float(np.median(np.diff(times)))
59
-
60
- # Collect bus-phase columns from the first snapshot (all snapshots
61
- # share the same set of buses for a given OpenDSS circuit).
62
- exclude = {b.lower() for b in exclude_buses}
63
- bus_names = [b for b in grid_states[0].voltages.buses() if b.lower() not in exclude]
64
-
65
- # Build (T, N) voltage matrix where N = num_buses * 3.
66
- T = len(grid_states)
67
- N = len(bus_names) * 3
68
- V = np.empty((T, N), dtype=float)
69
- for t, gs in enumerate(grid_states):
70
- col = 0
71
- for bus in bus_names:
72
- tp = gs.voltages[bus]
73
- V[t, col] = tp.a
74
- V[t, col + 1] = tp.b
75
- V[t, col + 2] = tp.c
76
- col += 3
77
-
78
- valid = ~np.isnan(V)
79
- worst_vmin = float(np.min(np.where(valid, V, np.inf)))
80
- worst_vmax = float(np.max(np.where(valid, V, -np.inf)))
81
-
82
- # Per-timestep violation: sum over all bus-phase pairs
83
- viol = np.where(valid, np.maximum(v_min - V, 0.0) + np.maximum(V - v_max, 0.0), 0.0)
84
- viol_sum = np.sum(viol, axis=1) # shape (T,)
85
-
86
- violation_steps = int(np.count_nonzero(viol_sum > 0.0))
87
- integral_violation = float(np.sum(viol_sum * dt))
88
-
89
- return VoltageStats(
90
- worst_vmin=float(worst_vmin),
91
- worst_vmax=float(worst_vmax),
92
- violation_time_s=float(violation_steps * dt),
93
- integral_violation_pu_s=float(integral_violation),
94
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openg2g/utils.py DELETED
@@ -1,18 +0,0 @@
1
- """Shared utility functions."""
2
-
3
- from __future__ import annotations
4
-
5
-
6
- def split_integer_evenly(n: int, k: int) -> list[int]:
7
- """Split integer *n* into *k* non-negative integers whose sum is *n*,
8
- differing by at most 1.
9
-
10
- Example:
11
-
12
- ```python
13
- split_integer_evenly(10, 3) # -> [4, 3, 3]
14
- split_integer_evenly(2, 5) # -> [1, 1, 0, 0, 0]
15
- ```
16
- """
17
- q, r = divmod(int(n), int(k))
18
- return [q + (1 if i < r else 0) for i in range(k)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml CHANGED
@@ -3,8 +3,8 @@ requires = ["setuptools>=64"]
3
  build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
- name = "openg2g"
7
- dynamic = ["version"]
8
  description = "A GPU-to-Grid simulation library for datacenter-grid cooperation."
9
  requires-python = ">=3.10"
10
  license = "Apache-2.0"
@@ -16,6 +16,8 @@ dependencies = [
16
  "aiohttp",
17
  "zeus>=0.15.0",
18
  "mlenergy-data",
 
 
19
  ]
20
 
21
  [project.urls]
@@ -38,11 +40,11 @@ dev = [
38
  {include-group = "examples"},
39
  ]
40
 
41
- [tool.setuptools.dynamic]
42
- version = {attr = "openg2g.__version__"}
43
-
44
  [tool.setuptools.packages.find]
45
- include = ["openg2g*"]
 
 
 
46
 
47
  [tool.ruff]
48
  target-version = "py310"
 
3
  build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
+ name = "bus-system-backend"
7
+ version = "0.1.0"
8
  description = "A GPU-to-Grid simulation library for datacenter-grid cooperation."
9
  requires-python = ">=3.10"
10
  license = "Apache-2.0"
 
16
  "aiohttp",
17
  "zeus>=0.15.0",
18
  "mlenergy-data",
19
+ "openg2g[opendss]",
20
+
21
  ]
22
 
23
  [project.urls]
 
40
  {include-group = "examples"},
41
  ]
42
 
 
 
 
43
  [tool.setuptools.packages.find]
44
+ where = ["."]
45
+ include = ["your_project_name*"]
46
+ exclude = ["data*", "outputs*", "scripts*", "tests*"]
47
+
48
 
49
  [tool.ruff]
50
  target-version = "py310"
server.py CHANGED
@@ -33,29 +33,25 @@ from openg2g.grid.config import TapPosition
33
  from openg2g.controller.tap_schedule import TapScheduleController
34
  from openg2g.metrics.voltage import compute_allbus_voltage_stats
35
 
36
- import asyncio, uuid, time
37
- from concurrent.futures import ProcessPoolExecutor
38
 
39
- import sqlite3, json
 
 
 
 
40
 
41
- conn = sqlite3.connect("jobs.db", check_same_thread=False, timeout=30)
42
- conn.execute("PRAGMA journal_mode=WAL;")
43
 
44
 
45
- # create table to track background simulation jobs
46
- conn.execute("""
47
- CREATE TABLE IF NOT EXISTS jobs (
48
- id TEXT PRIMARY KEY,
49
- status TEXT,
50
- result TEXT,
51
- error TEXT
52
- )
53
- """)
54
- conn.commit()
55
 
56
  #currently set to 2 for free tier at hf
57
  _pool = ProcessPoolExecutor(max_workers=2)
58
- _jobs: dict = {}
59
  _start_time = time.time()
60
 
61
 
@@ -134,7 +130,7 @@ def _get_trace_power(model_label: str, num_gpus: int, max_num_seqs: int,
134
  return [p * num_replicas for p in power_W]
135
 
136
 
137
- print(f" [startup] data dir: {_DATA_DIR} exists={_DATA_DIR.exists()}")
138
  _load_traces_index() # load at startup
139
 
140
 
@@ -330,7 +326,7 @@ def _run_full(req_dict: dict) -> dict:
330
 
331
 
332
  """Get per-bus voltage (worst phase per bus)."""
333
- def _voltages(gs, debug=False) -> list[float]:
334
  result = []
335
  for name in BUSES_ORDERED:
336
  try:
@@ -338,13 +334,13 @@ def _voltages(gs, debug=False) -> list[float]:
338
  vals = [float(v) for v in [tp.a, tp.b, tp.c]
339
  if not math.isnan(float(v)) and 0.5 < float(v) < 1.5]
340
  result.append(min(vals) if vals else None)
341
- except Exception:
 
342
  result.append(None)
343
  known = [v for v in result if v is not None]
344
  avg = sum(known) / len(known) if known else 1.0
345
  result = [v if v is not None else avg for v in result]
346
- if debug:
347
- print(f" [V] {[round(v,4) for v in result]}")
348
  return result
349
 
350
 
@@ -390,43 +386,10 @@ def health():
390
 
391
 
392
 
393
- @app.get("/api/status")
394
- def status():
395
- active = conn.execute(
396
- "SELECT COUNT(*) FROM jobs WHERE status='pending'"
397
- ).fetchone()[0]
398
-
399
- total = conn.execute(
400
- "SELECT COUNT(*) FROM jobs"
401
- ).fetchone()[0]
402
 
403
- return {
404
- "active_jobs": active,
405
- "total_jobs": total,
406
- "workers": _pool._max_workers,
407
- }
408
 
409
 
410
 
411
- @app.get("/api/job/{job_id}")
412
- def get_job(job_id: str):
413
- row = conn.execute(
414
- "SELECT status, result, error FROM jobs WHERE id=?",
415
- (job_id,)
416
- ).fetchone()
417
-
418
- if not row:
419
- raise HTTPException(404, "Job not found")
420
-
421
- status, result, error = row
422
-
423
- if status == "done":
424
- return {"status": status, "result": json.loads(result)}
425
- elif status == "error":
426
- return {"status": status, "detail": error}
427
- else:
428
- return {"status": status}
429
-
430
 
431
  """Return available traces"""
432
  @app.get("/api/traces")
@@ -459,18 +422,18 @@ def list_traces():
459
  """Baseline grid simulation, no workload"""
460
  @app.post("/api/powerflow")
461
  async def powerflow(req: PowerflowRequest):
462
- print(f"\nPowerflow v={req.substationVoltage}")
463
  try:
464
  dc = _build_dc(scale=0.001, duration_s=5)
465
  grid = _build_grid(req.substationVoltage, "671")
466
  log = _run(dc, grid, req.substationVoltage, "671", 5)
467
- vs = _voltages(log.grid_states[-1], debug=True)
468
- print(f" min={min(vs):.4f} max={max(vs):.4f}")
469
  return {"buses": [{"id": i+1, "voltage": v, "activePower": 0.0,
470
  "reactivePower": 0.0} for i, v in enumerate(vs)],
471
  "lines": []}
472
  except Exception as e:
473
- import traceback; traceback.print_exc()
474
  raise HTTPException(status_code=500, detail=str(e))
475
 
476
 
@@ -478,34 +441,16 @@ async def powerflow(req: PowerflowRequest):
478
  """Simulate AI workload impact on grid using GPU traces."""
479
  @app.post("/api/llm-impact")
480
  async def llm_impact(req: LLMImpactRequest):
481
- job_id = uuid.uuid4().hex
482
-
483
- conn.execute(
484
- "INSERT INTO jobs (id, status) VALUES (?, ?)",
485
- (job_id, "pending")
486
- )
487
- conn.commit()
488
-
489
- async def run_and_store():
490
- try:
491
- loop = asyncio.get_event_loop()
492
- result = await loop.run_in_executor(_pool, _run_full, req.dict())
493
-
494
- conn.execute(
495
- "UPDATE jobs SET status=?, result=? WHERE id=?",
496
- ("done", json.dumps(result), job_id)
497
- )
498
- conn.commit()
499
 
500
- except Exception as e:
501
- conn.execute(
502
- "UPDATE jobs SET status=?, error=? WHERE id=?",
503
- ("error", str(e), job_id)
504
- )
505
- conn.commit()
506
 
507
- asyncio.create_task(run_and_store())
508
- return {"job_id": job_id}
 
509
 
510
 
511
 
@@ -527,13 +472,12 @@ async def heatmap(req: HeatmapRequest):
527
 
528
 
529
  if __name__ == "__main__":
530
- print("\n" + "="*70)
531
- print("="*70)
532
- print(f" Data: {_DATA_DIR} ready={_DATA_DIR.exists()}")
533
  df = _load_traces_index()
534
  if not df.empty:
535
  models = df["model_label"].unique().tolist()
536
- print(f" Models: {models}")
537
- print(f" Traces: {len(df)} configurations")
538
- print("="*70 + "\n")
539
  uvicorn.run("server:app", host="0.0.0.0", port=8080, workers=1, log_level="info")
 
33
  from openg2g.controller.tap_schedule import TapScheduleController
34
  from openg2g.metrics.voltage import compute_allbus_voltage_stats
35
 
36
+ import logging
 
37
 
38
+ logging.basicConfig(
39
+ level=logging.INFO,
40
+ format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
41
+ )
42
+ logger = logging.getLogger(__name__)
43
 
 
 
44
 
45
 
46
+ import asyncio, time
47
+ from concurrent.futures import ProcessPoolExecutor
48
+
49
+ import json
50
+
 
 
 
 
 
51
 
52
  #currently set to 2 for free tier at hf
53
  _pool = ProcessPoolExecutor(max_workers=2)
54
+
55
  _start_time = time.time()
56
 
57
 
 
130
  return [p * num_replicas for p in power_W]
131
 
132
 
133
+ logger.info(f"Data dir: {_DATA_DIR} exists={_DATA_DIR.exists()}")
134
  _load_traces_index() # load at startup
135
 
136
 
 
326
 
327
 
328
  """Get per-bus voltage (worst phase per bus)."""
329
+ def _voltages(gs) -> list[float]:
330
  result = []
331
  for name in BUSES_ORDERED:
332
  try:
 
334
  vals = [float(v) for v in [tp.a, tp.b, tp.c]
335
  if not math.isnan(float(v)) and 0.5 < float(v) < 1.5]
336
  result.append(min(vals) if vals else None)
337
+ except Exception as e:
338
+ logger.debug(f"Bus {name} voltage unavailable: {e}")
339
  result.append(None)
340
  known = [v for v in result if v is not None]
341
  avg = sum(known) / len(known) if known else 1.0
342
  result = [v if v is not None else avg for v in result]
343
+ logger.debug(f"Voltages: {[round(v,4) for v in result]}")
 
344
  return result
345
 
346
 
 
386
 
387
 
388
 
 
 
 
 
 
 
 
 
 
389
 
 
 
 
 
 
390
 
391
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
  """Return available traces"""
395
  @app.get("/api/traces")
 
422
  """Baseline grid simulation, no workload"""
423
  @app.post("/api/powerflow")
424
  async def powerflow(req: PowerflowRequest):
425
+ logger.info(f"Powerflow request v={req.substationVoltage}")
426
  try:
427
  dc = _build_dc(scale=0.001, duration_s=5)
428
  grid = _build_grid(req.substationVoltage, "671")
429
  log = _run(dc, grid, req.substationVoltage, "671", 5)
430
+ vs = _voltages(log.grid_states[-1])
431
+ logger.info(f"Powerflow result min={min(vs):.4f} max={max(vs):.4f}")
432
  return {"buses": [{"id": i+1, "voltage": v, "activePower": 0.0,
433
  "reactivePower": 0.0} for i, v in enumerate(vs)],
434
  "lines": []}
435
  except Exception as e:
436
+ logger.exception("Powerflow failed")
437
  raise HTTPException(status_code=500, detail=str(e))
438
 
439
 
 
441
  """Simulate AI workload impact on grid using GPU traces."""
442
  @app.post("/api/llm-impact")
443
  async def llm_impact(req: LLMImpactRequest):
444
+ logger.info(f"Simulation request: {req.modelLabel} bus={req.targetBus}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
+ try:
447
+ loop = asyncio.get_event_loop()
448
+ result = await loop.run_in_executor(_pool, _run_full, req.dict())
449
+ return result
 
 
450
 
451
+ except Exception as e:
452
+ logger.exception("Simulation failed")
453
+ raise HTTPException(status_code=500, detail=str(e))
454
 
455
 
456
 
 
472
 
473
 
474
  if __name__ == "__main__":
475
+ logger.info("=" * 70)
476
+ logger.info(f"Data dir: {_DATA_DIR} ready={_DATA_DIR.exists()}")
 
477
  df = _load_traces_index()
478
  if not df.empty:
479
  models = df["model_label"].unique().tolist()
480
+ logger.info(f"Models: {models}")
481
+ logger.info(f"Traces: {len(df)} configurations")
482
+ logger.info("=" * 70)
483
  uvicorn.run("server:app", host="0.0.0.0", port=8080, workers=1, log_level="info")