Predicting High Uncertainty Events to Train Working Memory ( выжимка )
Artyom Sorokin | Dec 2021
Theory
Memory's Objective
Мы бы хотели чтобы наша память максимизировала следующую сумму:
Для простоты буду писать так:
Memory's Objective
Проблема: учить память предсказывать будущее \(f_t\) на шаге t может быть уже поздно:
Цель обучения памяти:
информация с шага \(t-k\) уже потеряна на шаге \(t\)
Memory's Objective
Чтобы выучится не выкидывать старую информацию придется на каждом шаге оптимизировать память относительно всех будущих шагов:
\(O(T^2)\) по времени
Идея MemUP:
вместо того, чтобы оптимизировать вторую сумму полностью, выберем шаги когда обучение памяти может дать наибольший вклад в предсказание будущего!
все еще нужно обрабатвать всю последовательность длинны \(T\)
Если поменять порядок сумм то получится 1 в 1 как учатся трансформер в RL
Finding Important Moments
насколько память может быть важна для предсказания \(f_t\)
Найдем моменты в будущем, когда выучивание памяти потенциально может принести максимальную пользу
Вообразим, что у нас есть идеальная память \(m^{\theta^*}\), тогда:
мелкая игнорим
оцениваем детектором
Общая схема обучения:
- Учим детектор \(d_{\psi}\) предсказывать \(f_t\) на каждом шаге на основе \(c_t\). Важно уметь давать оценку неопределенности предсказаний детектора \(\hat{H}_{\psi}(f_t| с_t)\).
- Учим память \(g_{\theta}\) для каждого шага t предсказывать будущие события \(U_t\), где память может быть наиболее важна:
\(U_t\) это набор шагов из эпизода для, которых детектор дал наибольшую оценку \(\hat{H}_{\psi}(f_i| с_i)\); \(|U| \ll T\).
При обучении памяти мы используем информацию из будущего, которой не будем владеть во время её применения, поэтому нужна отдельная сетка для обьединения шагов k и t: предиктор
Общая схема обучения:
Optimization
Учим детектор
Нужно чтобы детектор \(d_{\psi}\) умел оценивать энтропию \(H(f_t|c_t)\):
Чтобы оценить \(-log\, p(f_t|c_t)\), достаточно использовать Cross-Entropy Loss:
Мы не можем повлиять на \(H(f_t|c_t)\), поэтому минимизируя CE loss мы минимизируем \(D_{KL}\) между нашей моделькой и настроящим распределением.
если не сработает, будем искать более честную оценку энтропии
Наша оценка энтропии для бедных:
То что CE минимизурует \(D_{KL}\), результат супер известный, можно даже не расписывать в статье.
(1)
оценка энтропии по одному сэмплу или удивление
Учим детектор
Нужно чтобы детектор \(d_{\psi}\) умел оценивать энтропию \(H(f_t|c_t, m_t)\):
Чтобы оценить \(-log\, p(f_t|c_t)\), достаточно использовать Cross-Entropy Loss:
Мы не можем повлиять на \(H(f_t|c_t)\), поэтому минимизируя CE loss мы минимизируем \(D_{KL}\) между нашей моделькой и настроящим распределением.
если не сработает, будем искать более честную оценку энтропии
Наша оценка энтропии для бедных:
оценка энтропии по одному сэмплу или удивление
(1)
Учим память+предиктор
Память \(g_{\theta}\) и предиктор \(q_{\phi}\) должны максимизировать взаимную информацию :
Barber, Agakov (2004) доказали Lower Bound для взаимной информации
(перепишем для нашего случая):
вспоминаем ур.1 между \(CE\) и \(KL\) только для распределения \(p(f_k|m^{\theta}_t, c_k)\):
Т.к.
, значит
получается:
(2)
Учим память+предиктор
Память \(g_{\theta}\) и предиктор \(q_{\phi}\) должны максимизировать взаимную информацию :
Barber, Agakov (2004) доказали Lower Bound для взаимной информации
(перепишем для нашего случая):
от нас не зависит, игнор
NLL loss, минимизируем её и разом обновляем \(\theta\) и \(\phi\)
Итог: \(d_\psi\), \(g_\theta\), \(q_\phi\) учим предсказывать будущее \(f_t\) при помощи NLL loss
Учим память+предиктор
На всякий случай для проверки, можно прочитать следующие 2 статьи из которых взято доказательство для функции memory+predictor:
Architecture
Architecture
-
TrajGenerator создаем последовательности
- TargetCreator: может потрбоваться создавать цели предсказания (RL, self-supervision)
-
UncertaintyDetector: оцениваем неопределенность шагов в траектории
- СontextEncoder: нужен здесь и для предиктора
- EventSelector: Чтобы по неопределенности выбрать события для предсказаний памяти
-
EpisodeBuffer: Выбирает батчи эпизодов, запускает на них UncertaintyDetector, EventSelector
- Возможно нужен отдельный класс чтобы все собрать в батч
- MemoryModule: учится хранить информацию
- Predictor: позволяет учить памят на событиях из будущего.
Чтобы учить память используем всё. Чтобы учить детектор достаточно только первых 4ех
ContextEncoder
* Из данных эпизода нужно сформировать нормальные наблюдения. Причем, когда эпизод уже разбит на роллауты, может быть уже поздно что-то делать.
* Применяется перед использованием Detector'а и Predictor'а (должен быть одинаковый для этой пары)
Context
TrajGenerator
Какие варианты могут быть:
- Просто создавать, как copy-task: CopyGenerator(<аргументы задачи>)
- Загружать из датасета: DatasetLoader(<путь и тд.>)
- RL версия: RLGenerator( env, policy, etc.)
- SequentialRLGenerator(env, policy, etc.)
- ParallelRLGenerator(env, policy, MultiprocessingQueue?)
Поля:
* get_trajs(n?)
* TrajIterator
MemUP Выжимка
By supergriver
MemUP Выжимка
- 410