Improve week06/a2c-optional
dniku opened this issue · comments
via @q0o0p
Закончила проверять week6 (но некоторые всё ещё досдают)
В ШАД эту домашку сдали 108 человек, из них 26 сдали ActorCritic. Один прислал Кунг-Фу из week 08. В ВШЭ домашку сдали 27 человек, один из них ActorCritic. Многие сдали и то и другое. Трое сдали ActorCritic на tensorflow (двое из них на TF2). Двое сдали reinforce на tensorflow.
(в прошлом году, судя по таблице, ActorCritic сдали ~6 человек).
На проверку ШАДа + ВШЭ ушло 55 часов чистого времени, из этого 35 часов на 26 работ ActorCritic. В среднем 10 минут на reinforce и 1 час 20 минут на ActorCritic.
Почему на ActorCritic так много времени:
В целом эта домашка, в отличие от, например, ActorCritic-а в 8й неделе, почти полностью "творческая", то есть студент всё сам пишет, а не заполняет однострочные "<YOUR_CODE>". Как следствие, очень много вариантов, как в глобальном смысле построить взаимодействие классов, так и локально как что-то реализовать, и как задать константы. И столько же много ошибок.
Какие были ошибки:
В модели:
- нет нелинейностей между свёртками и после;
- функция активации поверх головы;
- (странность) объединение двух голов в один Linear слой с последующим разрезанием output-а;
- делают огромные слои, которые долго учатся;
- линейный слой перед головами делают для каждой головы свой, т.е. в 2 раза больше весов и теряется мультитаск;
- две полностью независимых сетки ValueNet и Policy;
В Policy:
- вместо transpose/permute вызывают reshape и массив заполняется не в том порядке (и свёртки теперь действуют не на рядом стоящие пиксели, а на что попало);
- сэмплим действие (целое число от 0 до 6), а потом torch.clamp его в диапазон от -1 до 1, то есть остаётся только два возможных действия - 0 (ничего не делать) и 1 (стрелять);
- возвращение log_probs_for_actions в качестве log_probs, и из-за этого ошибки в энтропии в дальнейшем;
- полное отключение от градиентов;
В ComputeValueTargets:
- инициализация последнего value_target нулём, недисконтированным значением "после последнего", потеря последнего реворда и ещё много вариантов разной фигни;
- не те знаки в ресетах (явно или в torch.where из-за неправильного понимания аргументов);
- не та логика ресетов, напр. зануление реворда на текущем шаге или пропускание информации через ресет дальше, с занулением только на самом шаге ресета;
- (недочёт) '.insert' в цикле и квадратичная сложность;
- (недочёт) перевод из numpy обратно в торч, чтобы считать градиенты (которые во-первых не нужны, во-вторых всё равно не дойдут до модели);
- заполняют массив с конца (что нормально), но забывают перевернуть его потом;
- value от latest_observation на каждом шаге участвует с одним и тем же коэффициентом (либо gamma, либо gamma ** traj_len);
- путают размерности (батч с шагами);
- вообще не используют реворды или latest_observation;
В лоссах:
- идут градиенты через value_targets;
- идут градиенты через values в policy loss;
- нет минуса в формуле энтропии или там, где она входит в лосс, или вообще энтропии нет;
- энтропия считается только по реально совершённым действиям;
- вместо mean берут сумму и делят на количество сред (забывают про шаги в средах);
- в policy_loss перед mean считают сумму, в результате mean берётся от одного числа;
- в policy_loss берут все возможные действия, как в энтропии, без учёта реально случившихся в траектории действий;
- прочие случаи, когда градиенты не текут там, где должны, и текут там, где не должны;
В целом:
- ошибки в torch.optim.lr_scheduler.LambdaLR, из-за которых лёрнинг рейт совсем не такой, как предполагают студенты;
- вместо 255 делят картинку на максимальное значение по текущему батчу, т.е. всегда на разное;
- численная нестабильность типа logprobs = tf.log(probs);
- optimizer.step() вызывается сразу после optimizer.zero_grad();
- разные питоновые неоптимальности;
По графикам:
- вывод ревордов не mean_100, а просто за один шаг или сессию;
- вывод клипнутых ревордов;
- неправильная реализация moving average и полный бред на графике;
- прочий рандом, из-за которого сразу не скажешь, что по оси x;
И в чате, и в Anytask студенты жаловались, что нет чёткой инструкции, как выводить график реворда.
Тем, кто делал на pytorch или на TF2, tensorflow_summaries не подходило.
Ответ "... mean 100 ..." студентов не удовлетворяет, им надо функции, которые это выводят. Совет "Бери питон и выводи" (как я делала в своё время) им не подходит. И правильно. Даже если они согласятся это делать сами, многие ошибутся.
Я скинула в чат PR с обновлённым atari_wrappers с функциональностью для вывода реворда без TF.
Некоторые это использовали в домашке, кто-то использовал TF, кто-то сам выводил. Но чат читают не все, а за обновлениями репозитория следит ещё меньше человек. И многие выводили реворд неправильно.
То есть, нельзя даже посмотреть на работу и провести sanity check по графику реворда. Это проблема и для студента, и для проверяющего.
То есть, при проверке, чтобы понять, что код работает (не тот график - всё равно что нет output'a), надо либо добавлять построение графика и запускать код, либо очень долго и внимательно вчитываться во все участки кода. Хорошо, если ошибки есть, тогда их можно найти. Но если их нет, то нет и гарантии, что их действительно нет, а не просто я их не заметила.
Было много работ, где с виду всё правильно, но реворд плохой, и надо было долго вчитываться во всё по нескольку раз и дебажить, чтобы найти, в чём проблема.
И, даже если всё верно, всё равно что-нибудь да есть, т.к. задание большое. Вот у человека всё верно, а он говорит, что-то tensorflow медленно работает, помогите найти почему.
Из этого всего следует, что потратить сильно меньше времени на ActorCritic невозможно.
Это я сейчас пишу для того, чтобы поднять этот вопрос на будущее. Если в следующем году студенты ШАДа будут такие же молодцы и сдадут столько же (или больше) ActorCritic'ов, очень вероятно, что у меня (или у того, кто будет это проверять) не будет столько времени. И придётся понизить качество проверки до неприемлемого уровня.
Как этого избежать:
- atari_wrappers уже обновлён для вывода ревордов, в следующем году все студенты это увидят. Возможно, надо в тексте ноутбука это как-то лучше описать;
- Больше ассертов. Добавить ассерт на ComputeValueTargets. Подумать, куда ещё можно добавить ассерты;
- надо приготовить и описать в ноутбуке примерные диапазоны референсных значений для всего: длина сессий, значения ревордов клипнутые и нет, и так далее;
- В ноутбуке документировать формат trajectory из runners.py;
- предлагайте ваши идеи...
В ноутбуке документировать формат trajectory из runners.py;
Когда я сдавал это задание, у меня была вот такая функция
def print_signature(d, indent=0):
for key, value in d.items():
print(' ' * indent + f'{key:<20}: ', end='')
if isinstance(value, list):
print('list')
print_signature(value[0], indent=indent+4)
if isinstance(value, np.ndarray):
print(f'ndarray {value.shape!s:<25} {value.dtype!s:<10}')
elif isinstance(value, torch.Tensor):
print(f'tensor {value.shape!s:<25} {value.dtype!s:<10}')
elif isinstance(value, dict):
print('dict')
print_signature(value, indent=indent+4)
else:
print(type(value))
print_signature(runner.get_next())
Пример вывода:
actions : ndarray (5, 8) int64
logits : tensor torch.Size([5, 8, 6]) torch.float32
log_probs : tensor torch.Size([5, 8, 6]) torch.float32
probs : tensor torch.Size([5, 8, 6]) torch.float32
values : tensor torch.Size([5, 8]) torch.float32
observations : ndarray (5, 8, 84, 84, 4) uint8
rewards : ndarray (5, 8) float64
resets : ndarray (5, 8) bool
state : dict
latest_observation : ndarray (8, 84, 84, 4) uint8
env_steps : <class 'int'>
value_targets : tensor torch.Size([5, 8]) torch.float32
Related: #181
Сейчас студент прислал пожелания на следующий год:
- пояснить в преамбуле, что играем 8 параллельных игр по 5 шагов.
- пояснить картинкой, что является обёрткой над чем (продирался целый день, прежде чем понял что авторы пытались донести всеми прототипами)
- подправить формулку для ComputeValueTargets
- добавить прототип get_entropy() в классе A2C
- неплохо бы сказать, что 160 - это плохо, но если первые 500 000 / 40 итераций не учится, то можно попробовать перезапустить (вообще неочевидно=))