coopihczoo.teaching.scripts_to_sort.behavioral_cloning_original

Functions

check_for_correct_spaces

Checks that the environment has the same spaces as the provided ones.

dataclass_quick_asdict

Extract dataclass to items using dataclasses.fields + dict comprehension.

flatten_trajectories

Flatten a series of trajectory dictionaries into arrays.

generate_trajectories

Generate trajectory dictionaries from a policy and an environment.

make_min_episodes

Terminate after collecting n episodes of data.

make_min_timesteps

Terminate at the first episode after collecting n timesteps of data.

make_sample_until

Returns a termination condition sampling for a number of timesteps and episodes.

rollout

Generate policy rollouts.

transitions_collate_fn

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

unwrap_traj

Uses RolloutInfoWrapper-captured obs and rews to replace fields.
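
The three termination helpers listed above (make_min_episodes, make_min_timesteps, make_sample_until) each return a predicate that decides when to stop collecting data. The following is a minimal sketch of that pattern, not the module's actual code. The names min_episodes, min_timesteps and sample_until are hypothetical, each episode is stood in for by a plain list of actions, and the choice to require both limits when both are given is an assumption.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in: an episode is just the list of actions taken in it.
Episode = Sequence[int]
StopCondition = Callable[[Sequence[Episode]], bool]


def min_episodes(n: int) -> StopCondition:
    """Stop once at least n complete episodes have been collected."""
    if n <= 0:
        raise ValueError("n must be positive")
    return lambda episodes: len(episodes) >= n


def min_timesteps(n: int) -> StopCondition:
    """Stop once the collected episodes contain at least n timesteps.

    The predicate only ever sees complete episodes, so collection ends at
    the first episode boundary after the n-th timestep.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    return lambda episodes: sum(len(ep) for ep in episodes) >= n


def sample_until(
    min_ts: Optional[int] = None, min_eps: Optional[int] = None
) -> StopCondition:
    """Combine both limits; assumed here: every limit given must be met."""
    if min_ts is None and min_eps is None:
        raise ValueError("specify at least one of min_ts and min_eps")
    conditions: List[StopCondition] = []
    if min_ts is not None:
        conditions.append(min_timesteps(min_ts))
    if min_eps is not None:
        conditions.append(min_episodes(min_eps))
    return lambda episodes: all(cond(episodes) for cond in conditions)


# Two finished episodes holding 3 + 2 = 5 timesteps in total.
collected = [[0, 1, 2], [1, 0]]
stop = sample_until(min_ts=4, min_eps=3)
keep_going = not stop(collected)  # enough timesteps, but only 2 episodes
```

Returning a closure keeps the rollout loop independent of the stopping rule: the loop only has to call the predicate on the episodes gathered so far after each episode finishes.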

Classes

BC

Behavioral cloning (BC).

ConstantLRSchedule

A callable that returns a constant learning rate.

EpochOrBatchIteratorWithProgress

Wraps DataLoader so that all BC batches can be processed in one for-loop.

FeedForward32Policy

A feed forward policy network with two hidden layers of 32 units.

RolloutInfoWrapper

Add the entire episode's rewards and observations to info at episode end.

Trajectory

A trajectory, e.g.

TrajectoryAccumulator

Accumulates trajectories step-by-step.

TrajectoryWithRew

A Trajectory that additionally includes reward information.

Transitions

A batch of obs-act-obs-done transitions.

TransitionsMinimal

A Torch-compatible Dataset of obs-act transitions.
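
As an illustration of the "callable that returns a constant learning rate" summary given for ConstantLRSchedule, here is a minimal sketch. The class name ConstantLR, the default value 1e-3 and the progress_remaining argument are assumptions made for this example and may differ from the module's real signature.

```python
class ConstantLR:
    """Hypothetical schedule: ignores training progress, returns a fixed rate."""

    def __init__(self, lr: float = 1e-3) -> None:
        self.lr = lr

    def __call__(self, progress_remaining: float) -> float:
        # progress_remaining is accepted only so the object can be used
        # wherever a schedule function of one argument is expected.
        return self.lr


schedule = ConstantLR(0.01)
start_lr = schedule(1.0)  # start of training
end_lr = schedule(0.0)    # end of training
```

Wrapping the constant in a callable lets code that expects a schedule function treat fixed and decaying learning rates uniformly.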