Replay Buffers

Several variants of replay buffers are included in rlpyt. Options include: n-step returns (computed by the replay buffer), prioritized replay (sum-tree), frame-based observation storage (for memory savings), and replay of sequences.

All buffers are based on a pre-allocated block of memory with leading dimensions [T,B], where B is the expected (and required) corresponding dimension in the input sample batches (which will be the number of parallel environments in the sampler), and T is chosen to attain the total requested buffer size. A universal time cursor tracks the position of the latest inputs along the T dimension of the buffer, and it wraps automatically. Use of namedarraytuples makes it straightforward to write data of arbitrary structure to the buffer’s next indexes. Further benefits are that pre-allocated storage doesn’t grow and is more easily shared across processes (async mode). But this format does require accounting for which samples are currently invalid due to partial memory overwrite, based on n-step returns or needing to replay sequences. If memory and performance optimization are less of a concern, it might be preferable to write a simpler buffer which, for example, stores a rotating list of complete sequences to replay.
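As a rough illustration of this layout (not rlpyt code; field names and sizes here are hypothetical), a buffer reduces to pre-allocated arrays plus a wrapping time cursor:

    import numpy as np

    T_total, B = 1000, 8     # buffer length along T, and number of parallel envs
    observation = np.zeros((T_total, B, 4, 84, 84), dtype=np.uint8)
    reward = np.zeros((T_total, B), dtype=np.float32)
    t = 0                    # universal time cursor

    def append_samples(obs_batch, rew_batch):
        """Write a [T_in, B, ...] batch at the cursor, wrapping along T."""
        global t
        T_in = rew_batch.shape[0]
        idxs = np.arange(t, t + T_in) % T_total
        observation[idxs] = obs_batch
        reward[idxs] = rew_batch
        t = (t + T_in) % T_total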

Hint

The implemented replay buffers share many components, and sub-classing with multiple inheritance is used to prevent redundant code. If modifying a replay buffer, it might be easier to first copy all desired components into one monolithic class, and then work from there.

Replay Buffer Components

Base Buffers

class rlpyt.replays.base.BaseReplayBuffer
append_samples(samples)

Add new data to the replay buffer, possibly ejecting old data.

sample_batch(batch_B, batch_T=None)

Returns a data batch, e.g. for training.

class rlpyt.replays.n_step.BaseNStepReturnBuffer(example, size, B, discount=1, n_step_return=1)

Bases: rlpyt.replays.base.BaseReplayBuffer

Stores the most recent data and computes n_step returns. Operations are all vectorized, as data is stored with leading dimensions [T,B]. Cursor is next idx to be written.

For now, assumes all incoming samples are “valid” (i.e. the sampler must use mid_batch_reset=True). This can be relaxed later by tracking validity for each data sample.

Subclass this with specific batch sampling scheme.

The latest n_step timesteps up to the cursor are temporarily invalid because not all of their future empirical rewards have been sampled yet (off_backward). The current cursor position is also an invalid sample, because the previous action and previous reward have been overwritten (off_forward).

Input example should be a namedtuple with the structure of the data (one example of each field, with no leading dimensions), matching what will be input every time samples are appended.

If n_step_return>1, then additional buffers samples_return_ and samples_done_n will also be allocated. The n-step return for a given sample will be stored at that same index (e.g. samples_return_[T,B] will store reward[T,B] + discount * reward[T+1,B] + discount ** 2 * reward[T+2,B] + …). done_n refers to whether a done=True signal appears within the n-step future, such that the following value should not be bootstrapped.

append_samples(samples)

Write the samples into the buffer and advance the time cursor. Handle wrapping of the cursor if necessary (boundary doesn’t need to align with length of samples). Compute and store returns with newly available rewards.

compute_returns(T)

Compute the n-step returns using the new rewards just written into the buffer, but before the buffer cursor is advanced. Input T is the number of new timesteps which were just written. Does nothing if n_step_return==1. E.g. for a 2-step return, t-1 is the first return written here, using the reward at t-1 and the new reward at t (up through t-1+T, from t+T).
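As a loop-based sketch of the return computation described here (the actual method is vectorized and also handles cursor wrap-around; names below are illustrative):

    import numpy as np

    def compute_nstep_returns(reward, done, discount, n_step):
        # reward, done: [T, B].  return_[t] accumulates discounted rewards
        # over the next n_step timesteps, stopping after the first done=True;
        # done_n[t] records whether a done occurred within that window.
        T, B = reward.shape
        return_ = np.zeros((T, B), dtype=np.float32)
        done_n = np.zeros((T, B), dtype=bool)
        for t in range(T - n_step + 1):
            ret = np.zeros(B, dtype=np.float32)
            dn = np.zeros(B, dtype=bool)
            for k in range(n_step):
                ret += np.where(dn, 0., discount ** k * reward[t + k])
                dn = np.logical_or(dn, done[t + k])
            return_[t] = ret
            done_n[t] = dn
        return return_, done_n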

class rlpyt.replays.frame.FrameBufferMixin(example, **kwargs)

Like the n-step return buffer, but expects multi-frame input observations where each new observation has one new frame and the rest old; stores only the unique frames to save memory. The samples’ observation should be shaped [T,B,C,..] with C the number of frames. Expects frame order: OLDEST to NEWEST.

A special method for replay will be required to piece the frames back together into full observations.

The latest n_step timesteps up to the cursor are temporarily invalid because “next” values are not yet written. The cursor timestep is invalid because the previous action and reward are overwritten. NEW: the next n_frames-1 timesteps are invalid because their observation history frames are overwritten.

append_samples(samples)

Appends all samples except for the observation as normal. Only the new frame in each observation is recorded.
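The memory saving can be sketched as follows (hypothetical shapes; assumes the frame stored at index t is the newest frame of observation t):

    # Full observations:  [T, B, 4, 84, 84]  -> each frame stored 4 times.
    # Frame-wise buffer:  [T, B, 84, 84]     -> each unique frame stored once.
    def append_new_frames(samples_frames, new_obs, t):
        # new_obs: [T_in, B, C, H, W], frames ordered OLDEST to NEWEST,
        # so only the last (newest) frame per timestep is written.
        T_in = new_obs.shape[0]
        idxs = np.arange(t, t + T_in) % samples_frames.shape[0]
        samples_frames[idxs] = new_obs[:, :, -1]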

class rlpyt.replays.async_.AsyncReplayBufferMixin(*args, **kwargs)

Mixin class which manages the buffer (shared) memory under a read-write lock (multiple-reader, single-writer), for use with the asynchronous runner. Wraps the append_samples(), sample_batch(), and update_batch_priorities() methods. Maintains a universal buffer cursor, communicated asynchronously. Supports multiple buffer-writer processes and multiple replay processes.

Non-Sequence Replays

class rlpyt.replays.non_sequence.n_step.NStepReturnBuffer(example, size, B, discount=1, n_step_return=1)

Bases: rlpyt.replays.n_step.BaseNStepReturnBuffer

Definition of what fields are replayed from basic n-step return buffer.

extract_batch(T_idxs, B_idxs)

From buffer locations [T_idxs,B_idxs], extract data needed for training, including target values at T_idxs + n_step_return. Returns namedarraytuple of torch tensors (see file for all fields). Each tensor has leading batch dimension len(T_idxs)==len(B_idxs), but individual samples are drawn, so no leading time dimension.

extract_observation(T_idxs, B_idxs)

Simply observation[T_idxs,B_idxs]; generalization anticipating frame-based buffer.

class rlpyt.replays.non_sequence.frame.NStepFrameBuffer(example, **kwargs)

Bases: rlpyt.replays.frame.FrameBufferMixin, rlpyt.replays.non_sequence.n_step.NStepReturnBuffer

Special method for re-assembling observations from frames.

extract_observation(T_idxs, B_idxs)

Assembles multi-frame observations from frame-wise buffer. Frames are ordered OLDEST to NEWEST along C dim: [B,C,H,W]. Where done=True is found, the history is not full due to recent environment reset, so these frames are zero-ed.
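A rough sketch of the re-assembly, under the same simplifying assumption that the frame at buffer index t is the newest frame of observation t, and ignoring cursor wrap-around and the done-based zeroing:

    def extract_observation(frames, T_idxs, B_idxs, n_frames):
        # frames: [T, B, H, W]; returns [len(T_idxs), n_frames, H, W],
        # stacking frames t-(n_frames-1) ... t as channels, OLDEST to NEWEST.
        return np.stack(
            [frames[T_idxs - (n_frames - 1) + f, B_idxs]
             for f in range(n_frames)],
            axis=1)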

class rlpyt.replays.non_sequence.uniform.UniformReplay

Replay of individual samples by uniform random selection.

sample_batch(batch_B)

Randomly selects the desired batch size of samples to return, using sample_idxs() and extract_batch().

sample_idxs(batch_B)

Randomly chooses the indexes of data to return using np.random.randint(), disallowing samples within a certain proximity to the current cursor, which hold invalid data.
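A sketch of that exclusion logic, assuming the buffer is already full (attribute names here are illustrative):

    def sample_idxs(self, batch_B):
        # Exclude the invalid window around the cursor: off_backward steps
        # behind it (incomplete n-step returns) and off_forward steps at/after it.
        valid_T = self.T - self.off_backward - self.off_forward
        T_idxs = (np.random.randint(low=0, high=valid_T, size=batch_B)
                  + self.t + self.off_forward) % self.T
        B_idxs = np.random.randint(low=0, high=self.B, size=batch_B)
        return T_idxs, B_idxs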

class rlpyt.replays.non_sequence.prioritized.PrioritizedReplay(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Prioritized experience replay using sum-tree prioritization.

The priority tree must be configured at instantiation to indicate whether priorities will be input with samples in append_samples(), via parameter input_priorities=True; otherwise the default value will be applied to all new samples.

append_samples(samples)

Looks for samples.priorities; if not found, uses the default priority. Writes samples using the superclass’s append_samples(), and advances the matching cursor in the priority tree.

sample_batch(batch_B)

Calls on the priority tree to generate random samples. Returns samples data and normalized importance-sampling weights: is_weights=priorities ** -beta

update_batch_priorities(priorities)

Takes in new priorities (i.e. from the algorithm after a training step) and sends them to priority tree as priorities ** alpha; the tree internally remembers which indexes were sampled for this batch.
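A typical interaction with these three methods, sketched from an algorithm’s point of view (compute_td_errors is a stand-in for the algorithm’s own loss computation; is_weights is the field described under sample_batch()):

    # Inside a training iteration (illustrative):
    samples = replay_buffer.sample_batch(batch_B=256)
    td_errors = compute_td_errors(samples)                # algorithm-specific
    loss = (samples.is_weights * td_errors ** 2).mean()   # importance-weighted loss
    # ... gradient step on loss ...
    replay_buffer.update_batch_priorities(td_errors.abs())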

class rlpyt.replays.non_sequence.time_limit.NStepTimeLimitBuffer(*args, **kwargs)

Bases: rlpyt.replays.non_sequence.n_step.NStepReturnBuffer

For use in e.g. SAC, to support bootstrapping the value when the env is done due to a time limit. Expects the input samples to include a timeout field, and returns timeout and timeout_n, analogous to done and done_n.

Sequence Replays

class rlpyt.replays.sequence.n_step.SequenceNStepReturnBuffer(example, size, B, rnn_state_interval, batch_T=None, **kwargs)

Bases: rlpyt.replays.n_step.BaseNStepReturnBuffer

Base n-step return buffer for sequence replays. Includes storage of the agent’s recurrent (RNN) state.

With rnn_state_interval>1, the RNN state is only stored periodically, to save memory. The replay mechanism must account for the fact that only time-steps with saved RNN state are valid first states for replay. (rnn_state_interval<1 does not store RNN state.)

append_samples(samples)

Special handling for RNN state storage, and otherwise uses superclass’s append_samples().

extract_batch(T_idxs, B_idxs, T)

Return the full sequence of each field in agent_inputs (e.g. observation), including all timesteps for the main sequence and for the target sequence in one array; many timesteps will likely overlap, so the algorithm can make sub-sequences by slicing on device, for reduced memory usage.

Enforces that input T_idxs align with RNN state interval.

Uses helper function extract_sequences() to retrieve samples of length T starting at locations [T_idxs,B_idxs], so returned data batch has leading dimensions [T,len(B_idxs)].
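A rough sketch of that helper’s behavior, ignoring wrap-around at the buffer boundary (which the real helper handles):

    def extract_sequences(array, T_idxs, B_idxs, T):
        # array: a buffer field with leading dims [T_total, B, ...].
        # Returns sequences stacked with leading dims [T, len(B_idxs), ...].
        return np.stack(
            [array[t:t + T, b] for t, b in zip(T_idxs, B_idxs)],
            axis=1)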

class rlpyt.replays.sequence.frame.SequenceNStepFrameBuffer(example, **kwargs)

Bases: rlpyt.replays.frame.FrameBufferMixin, rlpyt.replays.sequence.n_step.SequenceNStepReturnBuffer

Includes a special method for extracting observation sequences from a frame-wise buffer, where each time-step includes multiple frames. Each returned sequence will contain many redundant frames. (A more efficient way would be to turn the Conv2D into a Conv3D and only return unique frames.)

extract_observation(T_idxs, B_idxs, T)

Observations are re-assembled from frame-wise buffer as [T,B,C,H,W], where C is the frame-history channels, which will have redundancy across the T dimension. Frames are returned OLDEST to NEWEST along the C dimension.

Frames are zero-ed after environment resets.

class rlpyt.replays.sequence.uniform.UniformSequenceReplay

Replays sequences with starting state chosen uniformly randomly.

sample_batch(batch_B, batch_T=None)

The length of sequences to return can be input dynamically via batch_T; if None, the internally set value is used. Returns a batch with leading dimensions [batch_T, batch_B].

sample_idxs(batch_B, batch_T)

Randomly chooses the indexes of starting data to return using np.random.randint(), disallowing samples within a certain proximity to the current cursor which hold invalid data, including accounting for sequence length (so every state returned in the sequence will hold valid data). If the RNN state is only stored periodically, only starting states with stored RNN state are chosen.
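An illustrative sketch of choosing valid sequence starts, assuming a full buffer (the snapping to stored RNN states is simplified here; the real implementation samples directly among valid positions):

    def sample_start_idxs(t, T_total, B, batch_B, batch_T,
                          off_backward, off_forward, rnn_state_interval):
        # The invalid window around the cursor must also cover the sequence length.
        valid_T = T_total - batch_T - off_backward - off_forward
        T_idxs = (np.random.randint(0, valid_T, size=batch_B)
                  + t + off_forward) % T_total
        if rnn_state_interval > 1:
            # Only timesteps with stored RNN state are valid sequence starts.
            T_idxs = (T_idxs // rnn_state_interval) * rnn_state_interval
        B_idxs = np.random.randint(0, B, size=batch_B)
        return T_idxs, B_idxs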

class rlpyt.replays.sequence.prioritized.PrioritizedSequenceReplay(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Prioritized experience replay of sequences using sum-tree prioritization. The size of the sum-tree is based on the number of RNN states stored, since valid sequences must start with an RNN state. Hence using periodic storage with rnn_state_interval>1 results in a smaller, faster tree using less memory. Replay buffer priorities are indexed to the start of the whole sequence to be returned, regardless of whether the initial part is used only as RNN warmup.

Requires batch_T to be set and fixed at instantiation, so that the priority tree has a fixed scheme for which samples are temporarily invalid due to the looping cursor (the tree must set and propagate 0-priorities for those samples, so a dynamic batch_T could require additional tree operations for every sampling event).

Parameter input_priority_shift is used to assign input priorities to a starting time-step which is shifted from the samples input to append_samples(). For example, in R2D1, using replay sequences of 120 time-steps, with 40 steps for warmup and 80 steps for training, we might run the sampler with 40-step batches, and store the RNN state only at the beginning of each batch: rnn_state_interval=40. In this scenario, we would use input_priority_shift=2, so that the input priorities which are provided with each batch of samples are assigned to sequence start-states at the beginning of warmup (shifted 2 entries back in the priority tree). This way, the input priorities can be computed after seeing all 80 training steps. In the meantime, the partially-written sequences are marked as temporarily invalid for replay anyway, according to buffer cursor position and the fixed batch_T replay setting. (If memory and performance optimization are less of a concern, the indexing effort might all be simplified by writing a replay buffer which manages a list of valid trajectories to sample, rather than a monolithic, pre-allocated buffer.)
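Reading the R2D1 example above into concrete numbers (illustrative values only):

    replay_sequence_T = 120     # 40 warmup + 80 training steps per replayed sequence
    sampler_batch_T = 40        # sampler delivers 40-step batches
    rnn_state_interval = 40     # RNN state stored once per sampler batch
    # Priorities arrive with the final 40-step batch of a sequence, but should
    # index the sequence start (beginning of warmup), which lies
    # (120 - 40) / 40 = 2 stored RNN states earlier in the priority tree:
    input_priority_shift = 2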

append_samples(samples)

Like the non-sequence prioritized replay, except it also stores RNN state, and advances the priority tree cursor according to the number of RNN states stored (which might be less than the overall number of time-steps).

sample_batch(batch_B)

Returns batch with leading dimensions [self.batch_T, batch_B], with each sequence sampled randomly according to priority. (self.batch_T should not be changed).

Priority Tree

class rlpyt.replays.sum_tree.SumTree(T, B, off_backward, off_forward, default_value=1, enable_input_priorities=False, input_priority_shift=0)

Sum tree for matrix of values stored as [T,B], updated in chunks along T dimension, applying to the full B dimension at each update. Priorities represented as first T*B leaves of binary tree. Turns on/off entries in vicinity of cursor position according to “off_backward” (e.g. n_step_return) and “off_forward” (e.g. 1 for prev_action or max(1, frames-1) for frame-wise buffer). Provides efficient sampling from non-uniform probability masses.

Note

A single precision (float32) tree was tried, and it sometimes returned samples with priority 0.0, because subtraction during the tree cascade left the random value larger than the remaining sum; keeping float64 is suggested.

advance(T, priorities=None)

The cursor advances by T: priorities are set to zero in the vicinity of the new cursor position and turned on for the new samples since the previous cursor position. The optional param priorities can be None for the default, or of dimensions [T, B]; a [B] or scalar input will broadcast. (Must have enabled input_priorities=True when instantiating the tree.) These will be stored at the current cursor position, meaning these priorities correspond to the current values being added to the buffer, even though their priority might temporarily be set to zero until future advances.

sample(n, unique=False)

Get n samples, with replacement (default) or without. Use np.random.rand() to generate random values with which to descend the tree to each sampled leaf node. Returns T_idxs and B_idxs, and sample priorities.
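A minimal sketch of a single descent through an array-backed sum tree (index 0 is the root, leaves are the last capacity entries; rlpyt’s layout may differ in detail):

    def sample_leaf(tree, capacity, value):
        # `value` is drawn uniformly from [0, tree[0]), where tree[0] is the
        # total priority mass; each parent stores the sum of its two children.
        idx = 0
        while idx < capacity - 1:        # descend while idx is an internal node
            left = 2 * idx + 1
            if value < tree[left]:
                idx = left
            else:
                value -= tree[left]
                idx = left + 1
        leaf = idx - (capacity - 1)
        # For a [T,B] priority matrix stored row-major in the leaves:
        # T_idx, B_idx = divmod(leaf, B)
        return leaf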

update_batch_priorities(priorities)

Apply new priorities to tree at the leaf positions where the last batch was returned from the sample() method.

print_tree(level=None)

Print values for whole tree or at specified level.

class rlpyt.replays.sum_tree.AsyncSumTree(*args, **kwargs)

Bases: rlpyt.replays.sum_tree.SumTree

Allocates the tree into shared memory, and manages asynchronous cursor position, for different read and write processes. Assumes that writing to tree values is lock protected elsewhere, i.e. by the replay buffer.

Full Replay Buffer Classes

These are all defined purely as sub-classes composed of the above components.

Non-Sequence Replay

class rlpyt.replays.non_sequence.uniform.UniformReplayBuffer(example, size, B, discount=1, n_step_return=1)

Bases: rlpyt.replays.non_sequence.uniform.UniformReplay, rlpyt.replays.non_sequence.n_step.NStepReturnBuffer
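For reference, a hedged instantiation sketch following the constructor signature above (example_to_buffer, batch_spec, and samples_to_buffer are placeholders for algorithm-specific objects):

    from rlpyt.replays.non_sequence.uniform import UniformReplayBuffer

    replay_buffer = UniformReplayBuffer(
        example=example_to_buffer,   # namedarraytuple: one example of each field
        size=int(1e6),               # total number of transitions to hold
        B=batch_spec.B,              # number of parallel environments in sampler
        discount=0.99,
        n_step_return=1,
    )
    # Then, each iteration:
    #   replay_buffer.append_samples(samples_to_buffer)
    #   batch = replay_buffer.sample_batch(batch_B=256)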

class rlpyt.replays.non_sequence.uniform.AsyncUniformReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.uniform.UniformReplayBuffer

class rlpyt.replays.non_sequence.prioritized.PrioritizedReplayBuffer(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Bases: rlpyt.replays.non_sequence.prioritized.PrioritizedReplay, rlpyt.replays.non_sequence.n_step.NStepReturnBuffer

class rlpyt.replays.non_sequence.prioritized.AsyncPrioritizedReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.prioritized.PrioritizedReplayBuffer

class rlpyt.replays.non_sequence.frame.UniformReplayFrameBuffer(example, **kwargs)

Bases: rlpyt.replays.non_sequence.uniform.UniformReplay, rlpyt.replays.non_sequence.frame.NStepFrameBuffer

class rlpyt.replays.non_sequence.frame.PrioritizedReplayFrameBuffer(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Bases: rlpyt.replays.non_sequence.prioritized.PrioritizedReplay, rlpyt.replays.non_sequence.frame.NStepFrameBuffer

class rlpyt.replays.non_sequence.frame.AsyncUniformReplayFrameBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.frame.UniformReplayFrameBuffer

class rlpyt.replays.non_sequence.frame.AsyncPrioritizedReplayFrameBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.frame.PrioritizedReplayFrameBuffer

class rlpyt.replays.non_sequence.time_limit.TlUniformReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.non_sequence.uniform.UniformReplay, rlpyt.replays.non_sequence.time_limit.NStepTimeLimitBuffer

class rlpyt.replays.non_sequence.time_limit.TlPrioritizedReplayBuffer(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Bases: rlpyt.replays.non_sequence.prioritized.PrioritizedReplay, rlpyt.replays.non_sequence.time_limit.NStepTimeLimitBuffer

class rlpyt.replays.non_sequence.time_limit.AsyncTlUniformReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.time_limit.TlUniformReplayBuffer

class rlpyt.replays.non_sequence.time_limit.AsyncTlPrioritizedReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.non_sequence.time_limit.TlPrioritizedReplayBuffer

Sequence Replay

class rlpyt.replays.sequence.uniform.UniformSequenceReplayBuffer(example, size, B, rnn_state_interval, batch_T=None, **kwargs)

Bases: rlpyt.replays.sequence.uniform.UniformSequenceReplay, rlpyt.replays.sequence.n_step.SequenceNStepReturnBuffer

class rlpyt.replays.sequence.uniform.AsyncUniformSequenceReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.sequence.uniform.UniformSequenceReplayBuffer

class rlpyt.replays.sequence.prioritized.PrioritizedSequenceReplayBuffer(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Bases: rlpyt.replays.sequence.prioritized.PrioritizedSequenceReplay, rlpyt.replays.sequence.n_step.SequenceNStepReturnBuffer

class rlpyt.replays.sequence.prioritized.AsyncPrioritizedSequenceReplayBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.sequence.prioritized.PrioritizedSequenceReplayBuffer

class rlpyt.replays.sequence.frame.UniformSequenceReplayFrameBuffer(example, **kwargs)

Bases: rlpyt.replays.sequence.uniform.UniformSequenceReplay, rlpyt.replays.sequence.frame.SequenceNStepFrameBuffer

class rlpyt.replays.sequence.frame.PrioritizedSequenceReplayFrameBuffer(alpha=0.6, beta=0.4, default_priority=1, unique=False, input_priorities=False, input_priority_shift=0, **kwargs)

Bases: rlpyt.replays.sequence.prioritized.PrioritizedSequenceReplay, rlpyt.replays.sequence.frame.SequenceNStepFrameBuffer

class rlpyt.replays.sequence.frame.AsyncUniformSequenceReplayFrameBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.sequence.frame.UniformSequenceReplayFrameBuffer

class rlpyt.replays.sequence.frame.AsyncPrioritizedSequenceReplayFrameBuffer(*args, **kwargs)

Bases: rlpyt.replays.async_.AsyncReplayBufferMixin, rlpyt.replays.sequence.frame.PrioritizedSequenceReplayFrameBuffer