yandexdataschool / Practical_RL

A course in reinforcement learning in the wild

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Improve week06/a2c-optional

dniku opened this issue · comments

via @q0o0p

Закончила проверять week6 (но некоторые всё ещё досдают)

В ШАД эту домашку сдали 108 человек, из них 26 сдали ActorCritic. Один прислал Кунг-Фу из week 08. В ВШЭ домашку сдали 27 человек, один из них ActorCritic. Многие сдали и то и другое. Трое сдали ActorCritic на tensorflow (двое из них на TF2). Двое сдали reinforce на tensorflow.
(в прошлом году, судя по таблице, ActorCritic сдали ~6 человек).

На проверку ШАДа + ВШЭ ушло 55 часов чистого времени, из этого 35 часов на 26 работ ActorCritic. В среднем 10 минут на reinforce и 1 час 20 минут на ActorCritic.

Почему на ActorCritic так много времени:

В целом эта домашка, в отличие от, например, ActorCritic-а в 8й неделе, почти полностью "творческая", то есть студент всё сам пишет, а не заполняет однострочные "<YOUR_CODE>". Как следствие, очень много вариантов, как в глобальном смысле построить взаимодействие классов, так и локально как что-то реализовать, и как задать константы. И столько же много ошибок.

Какие были ошибки:

В модели:

  • нет нелинейностей между свёртками и после;
  • функция активации поверх головы;
  • (странность) объединение двух голов в один Linear слой с последующим разрезанием output-а;
  • делают огромные слои, которые долго учатся;
  • линейный слой перед головами делают для каждой головы свой, т.е. в 2 раза больше весов и теряется мультитаск;
  • две полностью независимых сетки ValueNet и Policy;

В Policy:

  • вместо transpose/permute вызывают reshape и массив заполняется не в том порядке (и свёртки теперь действуют не на рядом стоящие пиксели, а на что попало);
  • сэмплим действие (целое число от 0 до 6), а потом torch.clamp его в диапазон от -1 до 1, то есть остаётся только два возможных действия - 0 (ничего не делать) и 1 (стрелять);
  • возвращение log_probs_for_actions в качестве log_probs, и из-за этого ошибки в энтропии в дальнейшем;
  • полное отключение от градиентов;

В ComputeValueTargets:

  • инициализация последнего value_target нулём, недисконтированным значением "после последнего", потеря последнего реворда и ещё много вариантов разной фигни;
  • не те знаки в ресетах (явно или в torch.where из-за неправильного понимания аргументов);
  • не та логика ресетов, напр. зануление реворда на текущем шаге или пропускание информации через ресет дальше, с занулением только на самом шаге ресета;
  • (недочёт) '.insert' в цикле и квадратичная сложность;
  • (недочёт) перевод из numpy обратно в торч, чтобы считать градиенты (которые во-первых не нужны, во-вторых всё равно не дойдут до модели);
  • заполняют массив с конца (что нормально), но забывают перевернуть его потом;
  • value от latest_observation на каждом шаге участвует с одним и тем же коэффициентом (либо gamma, либо gamma ** traj_len);
  • путают размерности (батч с шагами);
  • вообще не используют реворды или latest_observation;

В лоссах:

  • идут градиенты через value_targets;
  • идут градиенты через values в policy loss;
  • нет минуса в формуле энтропии или там, где она входит в лосс, или вообще энтропии нет;
  • энтропия считается только по реально совершённым действиям;
  • вместо mean берут сумму и делят на количество сред (забывают про шаги в средах);
  • в policy_loss перед mean считают сумму, в результате mean берётся от одного числа;
  • в policy_loss берут все возможные действия, как в энтропии, без учёта реально случившихся в траектории действий;
  • прочие случаи, когда градиенты не текут там, где должны, и текут там, где не должны;

В целом:

  • ошибки в torch.optim.lr_scheduler.LambdaLR, из-за которых лёрнинг рейт совсем не такой, как предполагают студенты;
  • вместо 255 делят картинку на максимальное значение по текущему батчу, т.е. всегда на разное;
  • численная нестабильность типа logprobs = tf.log(probs);
  • optimizer.step() вызывается сразу после optimizer.zero_grad();
  • разные питоновые неоптимальности;

По графикам:

  • вывод ревордов не mean_100, а просто за один шаг или сессию;
  • вывод клипнутых ревордов;
  • неправильная реализация moving average и полный бред на графике;
  • прочий рандом, из-за которого сразу не скажешь, что по оси x;

И в чате, и в Anytask студенты жаловались, что нет чёткой инструкции, как выводить график реворда.

Тем, кто делал на pytorch или на TF2, tensorflow_summaries не подходило.

Ответ "... mean 100 ..." студентов не удовлетворяет, им надо функции, которые это выводят. Совет "Бери питон и выводи" (как я делала в своё время) им не подходит. И правильно. Даже если они согласятся это делать сами, многие ошибутся.

Я скинула в чат PR с обновлённым atari_wrappers с функциональностью для вывода реворда без TF.

Некоторые это использовали в домашке, кто-то использовал TF, кто-то сам выводил. Но чат читают не все, а за обновлениями репозитория следит ещё меньше человек. И многие выводили реворд неправильно.

То есть, нельзя даже посмотреть на работу и провести sanity check по графику реворда. Это проблема и для студента, и для проверяющего.

То есть, при проверке, чтобы понять, что код работает (не тот график - всё равно что нет output'a), надо либо добавлять построение графика и запускать код, либо очень долго и внимательно вчитываться во все участки кода. Хорошо, если ошибки есть, тогда их можно найти. Но если их нет, то нет и гарантии, что их действительно нет, а не просто я их не заметила.

Было много работ, где с виду всё правильно, но реворд плохой, и надо было долго вчитываться во всё по нескольку раз и дебажить, чтобы найти, в чём проблема.

И, даже если всё верно, всё равно что-нибудь да есть, т.к. задание большое. Вот у человека всё верно, а он говорит, что-то tensorflow медленно работает, помогите найти почему.

Из этого всего следует, что потратить сильно меньше времени на ActorCritic невозможно.

Это я сейчас пишу для того, чтобы поднять этот вопрос на будущее. Если в следующем году студенты ШАДа будут такие же молодцы и сдадут столько же (или больше) ActorCritic'ов, очень вероятно, что у меня (или у того, кто будет это проверять) не будет столько времени. И придётся понизить качество проверки до неприемлемого уровня.

Как этого избежать:

  • atari_wrappers уже обновлён для вывода ревордов, в следующем году все студенты это увидят. Возможно, надо в тексте ноутбука это как-то лучше описать;
  • Больше ассертов. Добавить ассерт на ComputeValueTargets. Подумать, куда ещё можно добавить ассерты;
  • надо приготовить и описать в ноутбуке примерные диапазоны референсных значений для всего: длина сессий, значения ревордов клипнутые и нет, и так далее;
  • В ноутбуке документировать формат trajectory из runners.py;
  • предлагайте ваши идеи...

В ноутбуке документировать формат trajectory из runners.py;

Когда я сдавал это задание, у меня была вот такая функция

def print_signature(d, indent=0):
    for key, value in d.items():
        print(' ' * indent + f'{key:<20}: ', end='')
        if isinstance(value, list):
            print('list')
            print_signature(value[0], indent=indent+4)
        if isinstance(value, np.ndarray):
            print(f'ndarray    {value.shape!s:<25} {value.dtype!s:<10}')
        elif isinstance(value, torch.Tensor):
            print(f'tensor     {value.shape!s:<25} {value.dtype!s:<10}')
        elif isinstance(value, dict):
            print('dict')
            print_signature(value, indent=indent+4)
        else:
            print(type(value))
            
print_signature(runner.get_next())

Пример вывода:

actions             : ndarray    (5, 8)                    int64     
logits              : tensor     torch.Size([5, 8, 6])     torch.float32
log_probs           : tensor     torch.Size([5, 8, 6])     torch.float32
probs               : tensor     torch.Size([5, 8, 6])     torch.float32
values              : tensor     torch.Size([5, 8])        torch.float32
observations        : ndarray    (5, 8, 84, 84, 4)         uint8     
rewards             : ndarray    (5, 8)                    float64   
resets              : ndarray    (5, 8)                    bool      
state               : dict
    latest_observation  : ndarray    (8, 84, 84, 4)            uint8     
    env_steps           : <class 'int'>
value_targets       : tensor     torch.Size([5, 8])        torch.float32

Related: #181

Сейчас студент прислал пожелания на следующий год:

  • пояснить в преамбуле, что играем 8 параллельных игр по 5 шагов.
  • пояснить картинкой, что является обёрткой над чем (продирался целый день, прежде чем понял что авторы пытались донести всеми прототипами)
  • подправить формулку для ComputeValueTargets
  • добавить прототип get_entropy() в классе A2C
  • неплохо бы сказать, что 160 - это плохо, но если первые 500 000 / 40 итераций не учится, то можно попробовать перезапустить (вообще неочевидно=))