coopihczoo.teaching.scripts_to_sort.behavioral_cloning_original
Functions
Checks that the environment has the same observation and action spaces as the provided ones.

Extracts a dataclass's fields into a dict using dataclasses.fields and a dict comprehension.

Flattens a series of trajectory dictionaries into arrays.

Generates trajectory dictionaries from a policy and an environment.

Terminates after collecting n episodes of data.

Terminates at the first episode boundary after collecting n timesteps of data.

Returns a termination condition that keeps sampling until both a minimum number of timesteps and a minimum number of episodes have been collected.

Generates policy rollouts.

Custom torch.utils.data.DataLoader collate_fn for TransitionsMinimal.

Replaces a trajectory's obs and rews fields with the copies captured by RolloutInfoWrapper.
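The termination helpers, rollout generation, and flattening described above fit together as one loop. The following is a minimal, stdlib-only sketch, not this module's code: `Trajectory`, `CountdownEnv`, and every function signature here are hypothetical stand-ins that mirror the summaries, and the toy environment assumes a Gym-style `reset()`/`step()` that returns `(obs, rew, done, info)`.

```python
import dataclasses
from typing import Callable, List, Optional, Sequence, Tuple


@dataclasses.dataclass(frozen=True)
class Trajectory:
    """Hypothetical trajectory: one more observation than actions."""
    obs: List[int]
    acts: List[int]


# A termination condition inspects the trajectories collected so far.
SampleUntil = Callable[[Sequence[Trajectory]], bool]


def make_min_episodes(n: int) -> SampleUntil:
    """Stop once at least n complete episodes have been collected."""
    return lambda trajs: len(trajs) >= n


def make_min_timesteps(n: int) -> SampleUntil:
    """Stop at the first episode boundary with at least n transitions."""
    return lambda trajs: sum(len(t.acts) for t in trajs) >= n


def make_sample_until(min_timesteps: Optional[int] = None,
                      min_episodes: Optional[int] = None) -> SampleUntil:
    """Stop only when every requested minimum is satisfied."""
    conds = []
    if min_timesteps is not None:
        conds.append(make_min_timesteps(min_timesteps))
    if min_episodes is not None:
        conds.append(make_min_episodes(min_episodes))
    if not conds:
        raise ValueError("specify min_timesteps and/or min_episodes")
    return lambda trajs: all(cond(trajs) for cond in conds)


def dataclass_quick_asdict(obj) -> dict:
    """Shallow dict of fields; unlike dataclasses.asdict, values are not copied."""
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


class CountdownEnv:
    """Hypothetical toy env: obs counts down from `length` to 0, then done."""

    def __init__(self, length: int = 3):
        self.length, self.t = length, 0

    def reset(self) -> int:
        self.t = 0
        return self.length

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        self.t += 1
        return self.length - self.t, 1.0, self.t >= self.length, {}


def generate_trajectories(policy: Callable[[int], int], env: CountdownEnv,
                          sample_until: SampleUntil) -> List[Trajectory]:
    """Roll out whole episodes until the termination condition fires."""
    trajs: List[Trajectory] = []
    while not sample_until(trajs):  # checked only between episodes
        obs = env.reset()
        obs_list, acts, done = [obs], [], False
        while not done:
            act = policy(obs)
            obs, _rew, done, _info = env.step(act)
            obs_list.append(obs)
            acts.append(act)
        trajs.append(Trajectory(obs=obs_list, acts=acts))
    return trajs


def flatten_trajectories(trajs: Sequence[Trajectory]):
    """Concatenate episodes into parallel obs/acts/next_obs/dones lists."""
    obs, acts, next_obs, dones = [], [], [], []
    for t in trajs:
        n = len(t.acts)
        obs.extend(t.obs[:-1])       # every observation except the last
        next_obs.extend(t.obs[1:])   # every observation except the first
        acts.extend(t.acts)
        dones.extend(i == n - 1 for i in range(n))  # only the last step is terminal
    return obs, acts, next_obs, dones


trajs = generate_trajectories(
    policy=lambda o: o % 2,
    env=CountdownEnv(length=3),
    sample_until=make_sample_until(min_timesteps=5, min_episodes=1),
)
obs, acts, next_obs, dones = flatten_trajectories(trajs)
```

With 3-step episodes and `min_timesteps=5`, the loop stops after two episodes (six transitions) rather than exactly five, because the timestep condition is only evaluated between episodes.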
Classes
Behavioral cloning (BC).

A callable that returns a constant learning rate.

Wraps DataLoader so that all BC batches can be processed in one for-loop.

A feed-forward policy network with two hidden layers of 32 units.

Adds the entire episode's rewards and observations to info at episode end.

A trajectory, e.g. a single-episode rollout from an expert policy.

Accumulates trajectories step by step.

A Trajectory that additionally includes reward information.

A batch of obs-act-obs-done transitions.

A Torch-compatible Dataset of obs-act transitions.
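The data classes and the BC trainer listed above can be illustrated without Torch. The sketch below is a hedged, stdlib-only approximation, not this module's implementation: `TrajectoryWithRew`, `TrajectoryAccumulator`, and `ConstantLRSchedule` are simplified hypothetical stand-ins with made-up signatures, and `behavioral_cloning_tabular` deliberately swaps the neural policy (such as the two-hidden-layer feed-forward network) and gradient descent for a lookup table fitted by counting. That table is the maximum-likelihood solution when observations and actions are discrete.

```python
import collections
import dataclasses
from typing import Dict, Hashable, List


class ConstantLRSchedule:
    """Callable that returns the same learning rate whatever the progress."""

    def __init__(self, lr: float = 1e-3):
        self.lr = lr

    def __call__(self, progress_remaining: float) -> float:
        return self.lr  # progress_remaining is deliberately ignored


@dataclasses.dataclass(frozen=True)
class TrajectoryWithRew:
    """Hypothetical trajectory with one reward per action."""
    obs: List[Hashable]  # len(acts) + 1 entries, including the final obs
    acts: List[Hashable]
    rews: List[float]


class TrajectoryAccumulator:
    """Builds trajectories step by step for several envs, keyed by env index."""

    def __init__(self):
        self.partial: Dict[int, List[dict]] = collections.defaultdict(list)

    def add_step(self, key: int, step: dict) -> None:
        self.partial[key].append(step)

    def finish_trajectory(self, key: int, final_obs: Hashable) -> TrajectoryWithRew:
        steps = self.partial.pop(key, [])
        return TrajectoryWithRew(
            obs=[s["obs"] for s in steps] + [final_obs],
            acts=[s["act"] for s in steps],
            rews=[s["rew"] for s in steps],
        )


def behavioral_cloning_tabular(trajs: List[TrajectoryWithRew]) -> Dict[Hashable, Hashable]:
    """Tabular BC: the maximum-likelihood policy's most probable action per obs."""
    counts = collections.defaultdict(collections.Counter)
    for t in trajs:
        for o, a in zip(t.obs[:-1], t.acts):  # rewards are not used by BC
            counts[o][a] += 1
    # most_common(1) -> [(action, count)]; ties go to the first-seen action.
    return {o: c.most_common(1)[0][0] for o, c in counts.items()}


acc = TrajectoryAccumulator()
# Three hypothetical demonstrator episodes stepped in lockstep; env 2 is noisy.
for key, first_act in [(0, "left"), (1, "left"), (2, "right")]:
    acc.add_step(key, {"obs": "s0", "act": first_act, "rew": 0.0})
for key in (0, 1, 2):
    acc.add_step(key, {"obs": "s1", "act": "right", "rew": 1.0})
demos = [acc.finish_trajectory(key, final_obs="goal") for key in (0, 1, 2)]

policy = behavioral_cloning_tabular(demos)
lr_schedule = ConstantLRSchedule(1e-3)
```

Behavioral cloning only needs observation-action pairs, so the rewards carried by `TrajectoryWithRew` are ignored, and the majority vote at `"s0"` outweighs the one noisy demonstration. In a gradient-based trainer, a schedule like `ConstantLRSchedule` would supply the optimizer's step size; the counting fit has no step size, so it appears here only to show the callable interface.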