"""Tests for LogisticModel: eval, derivative correctness, and numerical gradient check."""
from __future__ import annotations
import math
from mlenergy_data.modeling import LogisticModel
def test_eval_at_midpoint():
    """Evaluating exactly at the midpoint x0 yields the curve's half-height.

    The sigmoid term equals 0.5 at x = x0, so the output should be
    b0 + L/2 = 50 + 100/2 = 100.
    """
    model = LogisticModel(L=100.0, x0=5.0, k=1.0, b0=50.0)
    expected = 100.0
    error = abs(model.eval_x(5.0) - expected)
    assert error < 1e-10
def test_eval_at_extremes():
    """Far from x0 the sigmoid saturates, pinning y to its asymptotes.

    With k = 2, x = 100 drives the sigmoid to ~1 (y -> b0 + L = 110) and
    x = -100 drives it to ~0 (y -> b0 = 10).
    """
    model = LogisticModel(L=100.0, x0=5.0, k=2.0, b0=10.0)
    cases = (
        (100.0, 110.0),  # upper asymptote: b0 + L
        (-100.0, 10.0),  # lower asymptote: b0
    )
    for x, asymptote in cases:
        assert abs(model.eval_x(x) - asymptote) < 1e-5
def test_eval_batch():
    """eval() on a batch size must agree with eval_x() on log2 of that size."""
    model = LogisticModel(L=200.0, x0=7.0, k=0.5, b0=30.0)
    batch_sizes = (8, 16, 32, 64, 128, 256, 512)
    for batch_size in batch_sizes:
        log_batch = math.log2(batch_size)
        via_batch = model.eval(batch_size)
        via_log = model.eval_x(log_batch)
        assert abs(via_batch - via_log) < 1e-12
def test_derivative_numerical_gradient():
    """deriv_wrt_x must agree with a central finite-difference estimate of eval_x."""
    model = LogisticModel(L=150.0, x0=6.0, k=1.5, b0=20.0)
    step = 1e-7

    def central_difference(x):
        # Symmetric two-point estimate of dy/dx around x.
        return (model.eval_x(x + step) - model.eval_x(x - step)) / (2 * step)

    for point in (3.0, 5.0, 6.0, 7.0, 9.0):
        exact = model.deriv_wrt_x(point)
        approx = central_difference(point)
        assert abs(exact - approx) < 1e-4, (
            f"Gradient mismatch at x={point}: "
            f"analytical={exact:.8f}, numerical={approx:.8f}"
        )
def test_derivative_sign():
    """With k > 0, the sign of the derivative follows the sign of L.

    Positive L gives an increasing curve (positive slope at every probe);
    negative L flips it to a decreasing one (negative slope at every probe).
    """
    probe_points = (3.0, 5.0, 7.0)

    rising = LogisticModel(L=100.0, x0=5.0, k=1.0, b0=0.0)
    assert all(rising.deriv_wrt_x(x) > 0 for x in probe_points)

    falling = LogisticModel(L=-100.0, x0=5.0, k=1.0, b0=200.0)
    assert all(falling.deriv_wrt_x(x) < 0 for x in probe_points)
def test_derivative_peak_at_midpoint():
    """Derivative is maximized at x = x0.

    The slope at x0 must strictly exceed the slope at points on *both* sides
    of x0, at several distances. Checking a single point on one side would
    not catch a peak that is shifted toward the other side.
    """
    x0 = 5.0
    lp = LogisticModel(L=100.0, x0=x0, k=2.0, b0=0.0)
    d_mid = lp.deriv_wrt_x(x0)
    for offset in (0.25, 1.0, 2.0):
        for x in (x0 - offset, x0 + offset):
            d_off = lp.deriv_wrt_x(x)
            assert d_mid > d_off, (
                f"deriv at x0={x0} ({d_mid}) is not above deriv at x={x} ({d_off})"
            )
def test_numerical_stability_large_input():
    """Should not overflow for large inputs, and should still saturate correctly.

    Finiteness alone would accept an implementation that avoids overflow by
    returning an arbitrary finite value. The outputs are therefore also checked
    against the asymptotes, y -> b0 + L (= 100) for large x and y -> b0 (= 0)
    for very negative x. The slope must vanish at both ends.
    """
    lp = LogisticModel(L=100.0, x0=5.0, k=1.0, b0=0.0)
    for x, asymptote in ((1000.0, 100.0), (-1000.0, 0.0)):
        y = lp.eval_x(x)
        dy = lp.deriv_wrt_x(x)
        assert math.isfinite(y), f"eval_x({x}) is not finite: {y}"
        assert math.isfinite(dy), f"deriv_wrt_x({x}) is not finite: {dy}"
        # Far beyond x0 the sigmoid is saturated, so y sits on its asymptote.
        assert abs(y - asymptote) < 1e-6, f"eval_x({x})={y}, expected ~{asymptote}"
        # The logistic slope decays to zero in both tails.
        assert abs(dy) < 1e-6, f"deriv_wrt_x({x})={dy}, expected ~0"
|