Обучаем LLM на GPU-кластерах
Evgeniy NikitinНаткнулся на очень крутую вещь - The Ultra-Scale Playbook: Training LLMs on GPU Clusters. Если вам интересны темы оптимизации обучения нейронок, внутренностей Пайторча, параллелизации обучения LLM - рекомендую прочитать целиком. Здесь я подготовил краткую выжимку содержания для тех, кому не хочется читать весь длинный текст (и то она получилась немаленькой). Кратко опишу пять типов параллелизма (data, tensor, sequence/context, pipeline, expert), расскажу про ключевые трюки экономии видеопамяти (ZeRO-1/2/3, чекпойнтинг и акуммуляция градиентов, FlashAttention). Но за подробностями, чтоб лучше понять - читайте исходную статью.
Дополнительный плюс - в их тексте много ссылок на другие классные статьи, например:
- Как работает аллокация кэша в Пайторче и почему иногда помогает torch.cuda.memory.empty_cache()
- Почему смена GELU на ReLU может внезапно срезать потребление видеопамяти на четверть
- FlexAttention - обёртка FlashAttention в PyTorch, которая позволяет пробовать альтернативные варианты этеншна (например, SoftCap, Document Masking или ALiBi), не теряя преимуществ FlashAttention по скорости и потреблению памяти
- Оптимизация CUDA-кернелов - позволяет мощно вкатиться в то, как устроены и работают GPU
Чтиво некороткое и непростое, но даже поверхностное ознакомление неплохо расширяет кругозор. А если уж выборочно поразбираться и по ссылкам походить - вообще топ. Мне кажется, тема полезная, даже если вы не планируете сами обучать LLM на огромных кластерах - в целом улучшает понимание процесса обучения нейронок, особенно больших языковых моделей.
Параллелизуем обучение
При обучении LLM (хотя в целом это относится ко всем большим сеткам) возникает ряд проблем:
- Недостаток памяти - большие модели не вмещаются в одну GPU, а иногда и в одну машину с несколькими GPU
- Вычисления - обучение LLM дело дорогое, хочется делать это быстрее и эффективнее
- Коммуникация - обмен данными между GPU на одной машине проходит достаточно быстро, особенно с NVLink а вот обмен по сети может стать боттлнеком
Чаще всего нам хочется сократить потребление памяти, не сильно потеряв (или даже выиграв) в скорости вычислений и затратах на коммуникацию. Как это сделать?
Память у нас уходит на:
- Веса модели
- Градиенты, которые рассчитываются во время бэкворд пасса
- Стейт оптимайзера - если используем Adam или другой оптимайзер, хранящий какое-то состояние
- Активации, которые рассчитываются во время форвард пасса, их размер напрямую зависит от размера инпута (например, высота и ширина картинки или длина текста)
- Прочее - кернелы Пайторча, буферы, промежуточные результаты

Две старые техники (постоянно использовались и до LLM) позволяют уменьшить потребление памяти и впихнуть обучение даже на одну карту:
- Градиент-чекпойнтинг - сокращает потребление памяти за счёт добавления лишних вычислений. Во время форвард-пасса дропаем все или часть активаций, и перерасчитываем их во время бэкворд-пасса.
- Аккумуляция градиентов - разбиваем желаемый батч на микро-батчи (например, если наш желаемый батч-сайз - 32, а на карту помещается 4 сэмпла, то делаем 8 микро-батчей), не делаем обновление весов после каждого микро-батча, а аккумулируем вычисленные градиенты. После 8 форвард-бэквардов обновляем веса.
Основная идея аккумуляции градиентов отлично работает, если у нас на машине есть несколько видеокарт - мы можем параллелизовать наши forward/backward-пассы. Такой подход называется data parallelism. К сожалению, в этом случае мы сильно просаживаемся по эффективности использования GPU. Нам нужно дождаться, пока завершится бэкворд-пасс на всех GPU, собрать со всех градиенты и рассчитать текущий апдейт весов. Как реализовать data parallelism эффективнее?
- Начинать синхронизацию градиентов во время бэкворд-пасса. Как только мы посчитали градиентов последних слоёв, уже можно начинать собирать и суммировать градиенты с разных карт. Ещё лучше сгруппировать градиенты (например, по слоям) и синхронизировать их группами.
- Если мы хотим достичь очень большого батч-сайза (что является стандартным кейсом при обучении LLM), то у нас может не хватить GPU. В таком случае можно комбинировать data parallelism и аккумуляцию градиентов. В таком случае синхронизация градиентов требуется только последнего шага аккумуляции, поскольку обновление весов будет производиться только после него.
Шардирование
Это всё круто работает внутри одной машины с несколькими GPU, где обмен данными между GPU очень быстр - особенно через NVLink. Как только речь заходит о больших кластерах с десятками, сотнями и тысячами видеокарт - обмениваться градиентами приходится через сеть, что скейлится не очень хорошо. А ещё для некоторых особо больших моделей невозможно достичь даже батч-сайза размером 1 на одной видеокарте. Что делать?
Оказывается, можно уменьшить нагрузку на видеопамяти каждой карты, потому что нам необязательно держать полную копию весов, состояний оптимайзера и градиентов на каждой карте.
- Самое простое (ZeRO-1) - побить состояния оптимайзера по картам, рассчитывать апдейт для соответствующей части весов на каждой карте, а потом передавать этот апдейт на все карты, чтобы на каждой были актуальные веса модели.
- Раз мы рассчитываем на каждой карте апдейт только для части весов, то и градиенты нам нужны только для этой части (ZeRO-2). Остальные можно передать на нужную карту и дропнуть.
- Наконец, можно разделить и веса модели (ZeRO-3). Из минусов - во время форвард-пасса нам придётся на лету собирать и дропать веса текущего слоя. Частично это решается тем, что сбор весов для слоя N+1 мы можем делать во время вычислений слоя N.
Таким образом, с помощью ZeRO мы снижаем пиковое потребление памяти, но немного увеличиваем трафик между видеокартами и машинами.
Подход ZeRO не касается четвёртого, часто наиболее жирного, компонента - активаций. Особенно это актуально для длинных последовательностей токенов. Для решения этой проблемы есть своя техника - tensor parallelism.
Tensor и sequence parallelism

Перемножение двух огромных тензоров - чаще всего инпут-тензона (X) и тензора весов (W), необязательно делать на одной карточке. Мы можем разделить тот или иной тензор на кусочки:
- Можно разделить тензор весов на отдельные колонки (или группы колонок), произвести перемножение с каждой колонкой на отдельной карте, а потом слепить итоговый тензор.
- Либо можно разделить входной тензор на колонки, а тензор весов на строчки, произвести перемножение, а затем суммировать результаты с каждой карты.
Пример: матрица W (4096 × 11008) в FFN-слое. При TP = 4 распределяем столбцы W по 4 картам (по 2768 столбцов). Каждая GPU считает X * W_i, потом выполняется all-reduce, то есть,собираем полный выход. Либо делим по строкам, и выполняем all-gather.
Мы можем подобрать более эффективное разделение или их сочетание для каждой операции трансформер-слоёв. Например, для multihead-attention каждая GPU может производить вычисления для отдельной головы этеншна или группы голов этеншна.
При этом некоторые операции (например, LayerNorm) требуют сбора всех активаций на каждой карточке. Поэтому нам нужен ещё и sequence parallelism.

Хотя для таких операций нам и нужные полные активации, зато они, в отличие от этеншна, не зависят от других токенов. Это позволяет разделить последовательность на куски и рассчитать результат для каждого куска на отдельной карточке.
Важно отметить, что tensor и sequence parallelism добавляют коммуникацию между GPU, которая не может производиться внахлёст в вычислениями, поскольку нам нужно дождаться результатов со всех карт. Поэтому такая параллелизация обычно используется только внутри одной машины, а не по сети. К тому же, для очень длинных последовательностей активации в частях сетки, где используется tensor parallelism, всё равно будут занимать очень много памяти.
Context parallelism
Что же нам тогда делать с этеншн-блоками и длинными последовательностями? На помощь приходит context parallelism, а точнее его эффективная имплементация Ring Attention.
- Объединяем карточки (машины) в кольцо, у каждой карты будет предыдущий и следующий узлы
- Делим последовательность на куски, каждя карточка будет содержать свою часть Q_i, K_i, V_i
- Считаем на карточке этеншн-скоры для текущих частей keys и values
- Параллельно отправляем текущие keys и values на следующую карточку (машину)
Повторяем эти операции, пока не посчитали все этеншн-скоры.
Есть более эффективная форма Ring Attention - Zig-Zag Ring Attention, которая делит последовательность не на подряд текущие токены, а более равномерно (например, на первую GPU попадают вычисления первых и последних токенов). Благодаря этому достигается более равномерная вычислительная нагрузка на карточки.
Pipeline parallelism
Наконец, можно разделить жирную модель на куски по слоям и отправить каждый кусок модели на свою карточку. В этом случае нам нужно передавать между карточками активации от предыдущего слоя к следующему. В наивной имплементации есть огромная проблема - мы теряем кучу GPU-времени, на картинке это время простоя отображается серым цветом.
Есть различные способы уменьшения этого пузыря:
- All forward, all backward - разделить батчи на более маленькие, сначала делаем форвард по всем микро-батчам, потом начинаем бэкворд
- One forward, one backward - часть времени мы постоянно перескакиваем с бэкворда на одном микро-батче и форварда на другом
- И другие ещё более сложные схемы
Expert parallelism

Наконец, последний вид параллелизма - это всем известный Mixture-of-Experts. В данном случае мы можем поместить веса каждого "эксперта" на отдельную карту, поскольку их вычисления не зависят друг от друга. Остаётся лишь написать логику, которая отправляет наши тензоры на нужного эксперта. Часто эта форма параллелизма комбинируется с другими, например, expert + data.
Оптимизируем вычисления
Вторая (точнее третья, я здесь пропустил часть про их эксперименты по сравнению разных конфигураций параллелизма, включая комбинации разных типов) часть описывает различные оптимизации, которые учитывают специфику работы GPU.
Параллелизм - это хорошо, но мы бы хотели добиться и оптимальных вычислений на конкретной GPU. Для этого можно писать кастомные кернелы.
Есть несколько способов уровней ускорения вашего GPU-кода через кернелы:
- Ванильный PyTorch - быстро, просто, медленно
- Декоратор @torch.compile - автоматическая генерация высокопроизводительных кернелов по вашему PyTorch-коду
- Можно взять Тритон-кернел, сгенерированный @torch.compile и попробовать улучшить его руками
- Наконец, можно написать кернел прямо на CUDA
Какие есть техники оптимизации через кернелы?
- Memory coalescing - ускорение доступа ядер к глобальной GPU-памяти через объединение запросов к соседним локациям в памяти.
- Tiling - если вычисления организованы так, что треды одного блока обращаются к одним и тем же данным, то можно загружать кусочки матриц в быструю память, чтобы все треды производили на них вычисления, а потом загружать следующий кусочек и так далее. Результаты аккумулируются в промежуточной матрице.
- Minimizing control divergence - пишем кернелы так, чтобы все треды в ворпе (32 треда) использовали одну и ту же вычислительную инструкцию. Инструкции могут быть разные если, например, у нас в коде есть бранчевание через if.
- Fusing kernels - объединение нескольких операций в одну для уменьшения коммуникации с долгой памятью. Особенно актуально для pointwise-операций, которые не зависят от других частей инпута (например, функции активации).
FlashAttention
При вычислении этеншн-скоров у нас появляется две больших матрицы (особенно для длинных последовательностей) - матрица этеншн-скоров и матрица нормализованных вероятностей после софтмакса. Соответственно нам нужно вычислить всю матрицу скоров, передать её в долгую память, а затем отправить всё обратно на вычисление софтмаксовых вероятностей.
Вместо этого мы можем не материализовывать целые матрицы скоров, а вычислять их кусочками. Делим матрицы Q, K, V на блоки по сколько-то токенов. Далее идём по чанкам Q, и для каждого считаем этеншн-скоры по каждому чанку K. По ходу аккумулируем статистики, нужные для итоговой софтмакс-нормализации. Таким образом, мы снижаем пиковое потребление видеопамяти и уменьшаем количество перегонов туда-сюда.