Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
438 changes: 438 additions & 0 deletions 4-atari-hard/2-go-explore.py

Large diffs are not rendered by default.

372 changes: 372 additions & 0 deletions 4-atari-hard/3-robustify.py

Large diffs are not rendered by default.

152 changes: 152 additions & 0 deletions 4-atari-hard/env_go_explore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Go-Explore env setup (restore-based exploration on raw gymnasium ALE).

Separate plumbing from this folder's `env.py` (the PPO/RND envpool stack):
Go-Explore's exploration phase needs the emulator's save/restore API
(ale.cloneState / restoreState), which envpool does not expose. Each
(worker) process owns a single raw ALE env built by `make_restore_env`.
The harness binds promotion markers to the script PLUS the sibling modules
it actually imports, so 2-go-explore.py is hashed with THIS file, not env.py.

Protocol (Ecoffet et al. 2019/2021, exploration phase): fully deterministic —
frameskip 4, NO sticky actions, no no-ops, seed 0. Stochasticity only enters
in the (separate, later) robustification phase. The TimeLimit wrapper is
stripped (`.unwrapped`): its step counter is meaningless when episodes are
entered mid-trajectory via state restore.

★ Verified ALE pitfall (this machine, ale-py 0.11.2): right after
`restoreState`, `getRAM()` / screen reads still return the PRE-restore values
until the next `act()`. Callers must therefore derive cell keys only from
frames returned by `env.step()`, never from immediate post-restore reads.
"""
import argparse
import json
import os
import statistics
import time

import torch # checkpoint serialization only — there is no neural net here


def _atomic_save(state, path):
    """Write `state` to `path` via a temp file + rename.

    The checkpoint is first serialized to `<path>.tmp` and then moved over
    `path` with `os.replace`, so a crash mid-write never leaves a truncated
    file at `path`.

    Args:
        state: object handed to `torch.save`.
        path: destination file; missing parent directories are created.
    """
    parent = os.path.dirname(path)
    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError — only create the parent when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = f"{path}.tmp"
    torch.save(state, tmp)
    os.replace(tmp, path)


class RunLogger:
    """Optional run-directory outputs: metrics.jsonl, periodic / milestone /
    best checkpoints, resume, and a final.json summary. Inert when run_dir is
    None, so the script still runs standalone.

    Same contract as 4-atari-hard/env.py with one change: milestone
    checkpoints fire every 50M frames instead of 5M — a Go-Explore checkpoint
    carries the whole archive (~0.5 GB at 50k cells), and a 500M-step run
    would otherwise pile up 100 of them."""

    MILESTONE_EVERY = 50_000_000

    def __init__(self, run_dir, ckpt_every):
        """Create `<run_dir>/ckpt` and open `<run_dir>/metrics.jsonl` for append.

        Args:
            run_dir: output directory, or None to disable every output.
            ckpt_every: periodic checkpoint interval in frames; a falsy value
                disables `checkpoint()` (but not the save in `finalize()`).
        """
        self.dir = run_dir
        self.ckpt_dir = os.path.join(run_dir, "ckpt") if run_dir else None
        self.ckpt_every = ckpt_every
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)
        # buffering=1 -> line-buffered, so each metrics row hits disk as written
        self.f = open(os.path.join(run_dir, "metrics.jsonl"), "a", buffering=1) if run_dir else None
        self.t0, self.last_frames = time.time(), 0
        self.ckpt_last, self.ms_last, self.best = 0, 0, float("-inf")

    def log(self, frames, scalars):
        """Append one structured row (frames + sps + caller's scalars) to metrics.jsonl.

        `sps` is the frame delta divided by the wall time since the previous
        `log()` call (or since construction for the first row). No-op when
        there is no run directory or after `finalize()`.
        """
        if not self.f:
            return
        now = time.time()
        sps = (frames - self.last_frames) / max(now - self.t0, 1e-9)
        self.f.write(json.dumps({"ts": round(now, 1), "frames": frames, "sps": round(sps, 1), **scalars}) + "\n")
        self.t0, self.last_frames = now, frames

    def resolve_resume(self, resume_arg):
        """'auto' -> run_dir/ckpt/latest.pt, else a path, else None.

        Returns None whenever the resolved file does not exist.
        """
        if resume_arg == "auto" and self.ckpt_dir:
            cand = os.path.join(self.ckpt_dir, "latest.pt")
            return cand if os.path.exists(cand) else None
        if resume_arg and resume_arg != "auto":
            return resume_arg if os.path.exists(resume_arg) else None
        return None

    def checkpoint(self, frames, state_fn, gate=None):
        """Periodic 'latest', 50M-step milestone, and best-gate checkpoints.

        state_fn() builds the dict only when a save actually happens, and at
        most once per call even if several of the three conditions fire
        together (the state can be large, so rebuilding it per file is waste).

        Args:
            frames: current frame count.
            state_fn: zero-arg callable returning the checkpoint payload.
            gate: optional score; a value above the best seen so far triggers
                a `best.pt` save.
        """
        if not self.ckpt_dir or not self.ckpt_every:
            return
        built = []  # one-slot cache so state_fn() runs at most once per call

        def snapshot():
            if not built:
                built.append(state_fn())
            return built[0]

        if frames - self.ckpt_last >= self.ckpt_every:
            _atomic_save(snapshot(), os.path.join(self.ckpt_dir, "latest.pt"))
            self.ckpt_last = frames
        if frames - self.ms_last >= self.MILESTONE_EVERY:
            _atomic_save(snapshot(), os.path.join(self.ckpt_dir, f"step_{frames // 1_000_000}M.pt"))
            self.ms_last = frames
        if gate is not None and gate > self.best:
            self.best = gate
            _atomic_save(snapshot(), os.path.join(self.ckpt_dir, "best.pt"))

    def finalize(self, frames, game_returns, state_fn, k=100):
        """Final 'latest' checkpoint + a final.json result summary.

        The summary reports mean / population std over the last `k` entries
        of `game_returns`. Closes the metrics file; later `log()` calls are
        then silently ignored instead of hitting a closed handle.
        """
        if self.ckpt_dir:
            _atomic_save(state_fn(), os.path.join(self.ckpt_dir, "latest.pt"))
        if self.dir:
            tail = [float(x) for x in game_returns[-k:]]
            with open(os.path.join(self.dir, "final.json"), "w") as fh:
                json.dump({"frames_total": frames, "frames_unit": "agent_steps",
                           "gate_metric": "game_return_mean_lastK", "K": k,
                           "value_mean": statistics.fmean(tail) if tail else float("nan"),
                           "value_std": statistics.pstdev(tail) if len(tail) > 1 else 0.0,
                           "episodes_counted": len(tail)}, fh, indent=1)
        if self.f:
            self.f.close()
            self.f = None  # make the logger inert rather than leave a closed handle


# Benchmark key -> Gymnasium / ALE environment id. The keys double as the
# accepted values of the `--env` flag in parse_args() and as the lookup key
# in make_restore_env(). The "_goexplore" suffix marks a distinct benchmark
# protocol (deterministic, no sticky) — never cross-compare with the
# sticky-action `montezuma` numbers elsewhere in this repo.
ENV_IDS = {
    "montezuma_goexplore": "ALE/MontezumaRevenge-v5",
}


def parse_args():
    """Parse the command line for the Go-Explore exploration script.

    Returns the `argparse.Namespace` built from `sys.argv`. Every
    run-management flag defaults to None so the caller can tell "not given"
    apart from an explicit value and fall back to its in-file defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", choices=list(ENV_IDS), default="montezuma_goexplore",
                        help="which game to explore")

    # --- reproducibility / run-management flags (harness run contract) ---
    # (flag, value type, help text); all optional with default None.
    run_flags = (
        ("--seed", int,
         "seed for the action RNG (the emulator itself is deterministic)"),
        ("--total-frames", int,
         "override the in-file TOTAL_FRAMES budget (agent steps actually executed)"),
        ("--n-workers", int,
         "override the in-file N_WORKERS (parallel explorer processes)"),
        ("--run-dir", str,
         "run directory: write metrics.jsonl / ckpt / final.json here"),
        ("--ckpt-every", int,
         "periodic checkpoint interval in agent steps (resume-safe)"),
        ("--resume", str,
         "'auto' (run-dir/ckpt/latest.pt) or a checkpoint path"),
    )
    for flag, value_type, help_text in run_flags:
        parser.add_argument(flag, type=value_type, default=None, help=help_text)
    return parser.parse_args()


def make_restore_env(env_key):
    """Build one raw ALE env that exposes clone/restore.

    `env_key` is a key of `ENV_IDS`. The env is created with frameskip 4,
    sticky actions off and grayscale observations, then unwrapped and reset
    with seed 0.

    The ale_py / gymnasium imports are deliberately local so harness-side
    tests can stub this module without pulling in ale_py. The unwrapped env
    is returned because (per the original design note) TimeLimit's step
    counter would spuriously truncate restore-based exploration, and
    OrderEnforcing rejects step-after-restore patterns.
    """
    import ale_py
    import gymnasium as gym

    gym.register_envs(ale_py)
    wrapped = gym.make(
        ENV_IDS[env_key],
        frameskip=4,
        repeat_action_probability=0.0,  # deterministic — Phase 1 requirement
        obs_type="grayscale",
    )
    env = wrapped.unwrapped
    # canonical deterministic start; variation comes from action RNGs
    env.reset(seed=0)
    # sanity check that the spec really carries the no-sticky setting
    sticky_prob = env.spec.kwargs.get("repeat_action_probability", None)
    assert sticky_prob == 0.0
    return env
229 changes: 229 additions & 0 deletions 4-atari-hard/env_robustify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Robustification env plumbing — the backward algorithm of Go-Explore /
Salimans & Chen 2018 (arXiv:1812.03381), distilling a single demo into a
policy that works under sticky actions.

Faithful-but-small port of openai/atari-reset (+ the uber-research fork the
Nature paper used). Two pieces:

* `ReplayResetEnv` wraps one raw gymnasium ALE env. Each episode RESTORES to a
point along the demo (`starting_point`) and the agent plays forward from
there under sticky actions. The score counter is seeded with the demo's raw
return up to that point, so "did the agent do as well as the demo from here"
is a single comparison `score >= returns[-1]`.

* `ResetManager` owns the curriculum: starting points are staggered across the
worker pool near the demo's end and marched BACKWARD as the agent succeeds
(and nudged forward when it collapses). `max_starting_point -> 0` means the
policy now plays the whole game from reset — the real progress metric.

Design notes (verified against atari-reset wrappers.py / ppo.py):
1. Success = raw score (incl. demo prefix) >= demo's full return, minus an
allowed deficit. Move rule (code, not paper): new max start = first index
where cumsum(success_rate)/window >= move_threshold; else +nudge forward.
2. lag kill: stay within `allowed_lag` steps of the demo's pace, compared to
a windowed-min of returns (so faithful play through a negative reward
isn't falsely killed).
3. success kill: once as-good-as-demo, run exp(U(0,1)*7) extra steps then end
with `random_reset=True` — the trainer masks GAE across this artificial
boundary and the random length stops the agent timing the cutoff.
4. warm-up replay of the last `reset_steps_ignored` demo actions through the
step path warms the recurrent state; those transitions are `invalid` and
masked from every loss.
5. trained WITH sticky actions (Go-Explore, not S&C deterministic) so the
policy is robust to the eval protocol; 0-30 no-ops when starting at reset.
"""
import pickle

import numpy as np


class StickyActionEnv:
    """Sticky-action filter: with probability `p` the previously executed
    action is repeated instead of the requested one.

    (Original design note: repeat_action_probability applied BELOW frameskip —
    sticky at the raw action level, the standard v5 stochasticity. The ALE env
    is built with sticky 0 and stickiness is added here so the demo replay,
    which must be deterministic, can bypass it.)"""

    def __init__(self, p=0.25):
        self.last = 0  # last action actually executed; 0 until the first filter()
        self.p = p

    def reset(self):
        """Forget the previous action (call at episode start)."""
        self.last = 0

    def filter(self, action, rng):
        """Return the action to execute for this step.

        Draws one `rng.random()` sample per call; when the draw is below `p`
        the remembered action is returned and kept, otherwise `action`
        becomes the remembered action and is returned.
        """
        repeat_previous = rng.random() < self.p
        if not repeat_previous:
            self.last = action
        return self.last


class ReplayResetEnv:
    """One raw ALE env that starts episodes from demo states. Not a gym env —
    the vectorized loop in 3-robustify.py drives it directly.

    `reset()` must be called before `step()`: the per-episode counters
    (`score`, `action_nr`, `start_nr`, `extra`) are created there.
    """

    def __init__(self, demo, seed, *, sticky=0.25, allowed_lag=50,
                 allowed_score_deficit=0, reset_steps_ignored=0,
                 inc_entropy_threshold=100, noop_max=30, max_steps=400_000):
        """Build the emulator and unpack the demo.

        Args:
            demo: mapping with keys "env_id", "actions", "rewards", "returns",
                "checkpoints" and "checkpoint_action_nr".
            seed: seed for this env's private numpy Generator.
            sticky: sticky-action probability; 0 disables the filter.
            allowed_lag: step window used by the lag kill in `step()`.
            allowed_score_deficit: slack subtracted from the score thresholds.
            reset_steps_ignored: how many demo steps before the chosen start
                the emulator is restored to (warm-up span).
            inc_entropy_threshold: steps after the start during which
                `info["increase_entropy"]` is True.
            noop_max: upper bound (inclusive) on random no-ops after a full reset.
            max_steps: episode length cap, counted from the start point.
        """
        import ale_py
        import gymnasium as gym
        gym.register_envs(ale_py)
        self.env = gym.make(demo["env_id"], frameskip=4,
                            repeat_action_probability=0.0,  # we add sticky ourselves
                            obs_type="grayscale").unwrapped
        self.ale = self.env.ale
        self.actions = demo["actions"]
        self.rewards = demo["rewards"]
        self.returns = demo["returns"]  # cumulative raw, return-to-here
        self.total_return = float(self.returns[-1])
        self.checkpoints = demo["checkpoints"]
        self.ckpt_nr = demo["checkpoint_action_nr"]
        self.n = len(self.actions)
        self.sticky = StickyActionEnv(sticky) if sticky > 0 else None
        self.allowed_lag = allowed_lag
        self.allowed_score_deficit = allowed_score_deficit
        self.reset_steps_ignored = reset_steps_ignored
        self.inc_entropy_threshold = inc_entropy_threshold
        self.noop_max = noop_max
        self.max_steps = max_steps
        self.rng = np.random.default_rng(seed)
        self.starting_point = self.n - 1
        self.frac_sample = 0.2

    # --- frame preprocessing: 105x80 grayscale (atari-reset uses RGB; grayscale
    # keeps us light and matches the rest of this repo). 4-stack handled in
    # the trainer. Returns uint8 (105, 80). ---
    def _frame(self):
        import cv2
        g = self.ale.getScreenGrayscale()
        # cv2 takes dsize as (width, height)
        return cv2.resize(g, (80, 105), interpolation=cv2.INTER_AREA)

    def _restore_to(self, nr):
        """Put the emulator in the demo state just before action index `nr`.

        Restores the latest checkpoint at or before `nr`, then replays the
        demo actions from that checkpoint up to (not including) `nr` without
        the sticky filter, so the replay is deterministic.

        Returns None: frames seen during replay are discarded. The caller
        gets its first valid frame from a real `_step_raw()` afterwards.
        """
        ci = int(np.searchsorted(self.ckpt_nr, nr, side="right") - 1)
        ci = max(ci, 0)
        # NOTE(review): pickle.loads executes arbitrary code from the payload —
        # only load demo files from a trusted source.
        self.ale.restoreState(pickle.loads(self.checkpoints[ci]))
        for i in range(int(self.ckpt_nr[ci]), nr):
            self.ale.act(int(self.actions[i]))

    def reset(self):
        """Start an episode from a demo state (or a true reset) and return the
        first frame."""
        # per-episode starting point: 0.8 at the pinned point, 0.2 uniform tail
        if self.rng.random() < self.frac_sample:
            nr = int(self.rng.integers(self.starting_point, self.n))
        else:
            nr = self.starting_point
        if self.sticky:
            self.sticky.reset()

        if nr <= 0:
            # whole game from reset: fresh emulator seed + random no-op prefix
            self.env.reset(seed=int(self.rng.integers(2 ** 31)))
            for _ in range(int(self.rng.integers(self.noop_max + 1))):
                self.ale.act(0)
            self.score = 0.0
            self.action_nr = 0
            self.start_nr = 0
        else:
            warm = max(nr - self.reset_steps_ignored, 0)
            self._restore_to(warm)
            # seed the score with the demo's return up to the restore point
            self.score = float(self.returns[warm - 1]) if warm > 0 else 0.0
            self.action_nr = warm
            self.start_nr = nr  # success/entropy measured against the true start
        self.extra = 0  # 0 = success countdown not armed
        # post-restore screen reads are STALE until the next act — take one real
        # NOOP to get a valid frame (not counted toward score/pace).
        frame, _ = self._step_raw(0, bookkeep=False)
        return frame

    def _step_raw(self, action, *, bookkeep=True):
        """Execute one (sticky-filtered) action; return (frame, raw reward).

        With `bookkeep=False` the score and action counter are left untouched.
        """
        # NOTE(review): this calls ale.act directly rather than env.step;
        # presumably the frameskip=4 given to gym.make is applied inside
        # env.step — confirm one ale.act here advances the same number of
        # emulator frames as one recorded demo action.
        a = self.sticky.filter(action, self.rng) if self.sticky else action
        r = self.ale.act(int(a))
        if bookkeep:
            self.score += float(r)
            self.action_nr += 1
        return self._frame(), float(r)

    def step(self, action):
        """Advance one agent step.

        Returns (frame, sign-clipped reward, done, info). `info` always holds
        "raw_reward" and "increase_entropy"; "random_reset" and
        "as_good_as_demo" are added when the success countdown ends the episode.
        """
        frame, raw_r = self._step_raw(action)
        info = {"raw_reward": raw_r}
        done = False

        # success: as good as the demo from here -> arm a random-length countdown
        if self.extra == 0 and self.score >= self.total_return - self.allowed_score_deficit:
            self.extra = int(np.exp(self.rng.random() * 7))  # 1..1096
        if self.extra > 0:
            self.extra -= 1
            if self.extra == 0:
                done = True
                info["random_reset"] = True
                info["as_good_as_demo"] = True

        # lag kill: fell behind the demo's pace (windowed-min, deficit-aware)
        t = self.action_nr
        if not done and t > self.allowed_lag and t < self.n:
            lo = max(t - self.allowed_lag, 0)
            hi = min(t + self.allowed_lag, self.n)
            threshold = float(self.returns[lo:hi].min()) - self.allowed_score_deficit
            if self.score < threshold:
                done = True

        if self.ale.game_over() or self.action_nr - self.start_nr >= self.max_steps:
            done = True
        info["increase_entropy"] = (self.action_nr < self.start_nr + self.inc_entropy_threshold)
        return frame, np.sign(raw_r), done, info  # clipped reward to the agent


class ResetManager:
    """Shared backward curriculum for a pool of N envs.

    The trainer calls assign() once to stagger the envs' starting points,
    record() to store per-start success, and update() whenever a batch of
    episodes finishes to move max_starting_point (backward on success,
    forward by `nudge` otherwise)."""

    def __init__(self, demo, n_envs, *, move_threshold=0.1, nudge=100, window=None):
        demo_len = len(demo["actions"])
        self.n = demo_len
        self.n_envs = n_envs
        self.move_threshold = move_threshold
        self.nudge = nudge
        # window = the span of staggered starting points (atari-reset nrstartsteps).
        # The move target is move_threshold*window of cumulative success mass.
        self.window = window or max(n_envs, 32)
        # curriculum begins at the demo's last action; max_max is its ceiling
        self.max_max = demo_len - 1
        self.max_starting_point = demo_len - 1
        # latest success-rate per starting-point index
        self.success = np.zeros(demo_len + 1, dtype=np.float64)

    def assign(self, envs):
        """Spread the envs' starting points downward from max_starting_point,
        evenly spaced over the window and clamped at 0."""
        stride = max(self.window // max(self.n_envs, 1), 1)
        top = self.max_starting_point
        for rank, env in enumerate(envs):
            env.starting_point = max(top - rank * stride, 0)

    def record(self, starting_point, success):
        """Store the newest success value for a starting point (latest wins;
        indices past the demo end are clamped to the final slot)."""
        slot = min(starting_point, self.n)
        self.success[slot] = float(success)

    def update(self, envs):
        """Apply the move rule, re-stagger the envs, return the new maximum.

        Move rule (atari-reset ResetManager.proc_infos): cumulatively sum the
        stored success values from index 0 up to the current maximum; the new
        maximum is the FIRST index whose running sum reaches
        move_threshold*window. When that mass is never reached, the
        curriculum is pushed forward (easier) by `nudge`, capped at max_max.
        """
        required_mass = self.move_threshold * self.window
        running = np.cumsum(self.success[: self.max_starting_point + 1])
        reached = np.flatnonzero(running >= required_mass)
        if reached.size:
            earliest = int(reached[0])  # earliest index reaching the mass
            self.max_starting_point = max(min(earliest, self.max_starting_point), 0)
        else:
            pushed = self.max_starting_point + self.nudge
            self.max_starting_point = min(pushed, self.max_max)
        self.assign(envs)
        return self.max_starting_point


def load_demo(path):
    """Read a pickled demo from `path` and return the unpickled object."""
    # NOTE(review): unpickling can execute arbitrary code — only point this
    # at demo files produced by a trusted run.
    with open(path, "rb") as fh:
        demo = pickle.load(fh)
    return demo
Loading