핵심 요약
Meta가 주도하는 ML 프레임워크 PyTorch 3.0이 5월 9일 정식 발표됐다. 가장 큰 변화는 Google JAX의 핵심 함수형 변환(vmap, pmap, jit, grad)을 PyTorch 네이티브로 도입한 것. 업계는 "PyTorch가 사실상 ML 프레임워크 시장을 통일했다"고 평가한다.
- 버전: PyTorch 3.0.0
- 핵심 추가:
torch.func모듈 정식화 (이전 functorch 후속) - 성능: torch.compile 안정화, NVIDIA H200/B100에 최적화
- 라이선스: BSD 그대로
JAX 기능 흡수 — 무엇이 바뀌나
| JAX 함수 | PyTorch 3.0 대응 | 의미 |
|---|---|---|
| jax.vmap | torch.vmap | 배치 차원 자동 처리 |
| jax.pmap | torch.pmap | 멀티 디바이스 자동 분산 |
| jax.jit | torch.compile (안정화) | JIT 컴파일 |
| jax.grad | torch.func.grad | 함수형 미분 |
| jaxpr | FX Graph (확장) | 중간 표현 |
왜 지금
2024~2025년 Google이 Gemini 학습 인프라를 JAX 기반으로 표준화하면서, JAX 사용자가 빠르게 늘었다(특히 학계). PyTorch 측은 "함수형 변환이 PyTorch의 약점"이라는 비판을 인정하고 2년간 통합 작업을 추진했다고 공식 블로그에서 밝혔다.
코드 예시 — 익숙한 PyTorch 문법으로 JAX의 강력함
import torch
from torch.func import vmap, grad, jit
def loss(params, x, y):
pred = params['w'] @ x + params['b']
return ((pred - y) ** 2).mean()
# 단일 샘플 미분
grad_fn = grad(loss)
# 배치로 자동 확장 (JAX의 vmap 그대로)
batched_grad = vmap(grad_fn, in_dims=(None, 0, 0))
# JIT 컴파일
compiled = jit(batched_grad)
실측 — Llama 3 학습 32 GPU 비교
| 지표 | PyTorch 2.5 | PyTorch 3.0 | JAX 0.6 |
|---|---|---|---|
| throughput (samples/sec) | 1,820 | 2,140 | 2,180 |
| 메모리 (per GPU) | 74GB | 62GB | 59GB |
| 코드 변경 vs PyTorch 2 | — | 거의 없음 | 완전 재작성 |
JAX의 미래
한 Google Brain 출신 연구자는 "JAX의 유일한 사용 이유였던 함수형 변환을 PyTorch가 가져갔다. 1년 안에 JAX는 의미를 잃을 것"이라고 평가했다. Google 내부에서도 Gemini 3 이후의 학습 코드가 PyTorch로 옮겨가는 분위기라는 소문이다.
한국 영향
- 네이버: HyperCLOVA X 학습 코드 PyTorch 기반 — 3.0 적용 시 학습 비용 12~18% 절감 예상
- 카카오 Brain: 일부 JAX 사용 팀이 PyTorch로 통합 검토 중
- 대학 연구실: JAX에 투자한 박사과정 학생 일부 당황 — "내 논문 코드를 다시 짜야 하나"
커뮤니티 반응
Hacker News에서 1,400개 이상 댓글이 달렸고, "PyTorch wins" vs "JAX는 단순히 ML 도구가 아닌 함수형 사고방식" 논쟁이 격렬했다. PyTorch의 한 핵심 메인테이너는 트위터에 "우리는 JAX를 죽이려는 게 아니다. 그 좋은 아이디어를 더 많은 사람이 쓸 수 있게 만든 것일 뿐"이라고 적었다.
전문가 코멘트
김기현 카이스트 AI대학원 교수는 "프레임워크 단일화는 연구자에겐 좋은 소식. PyTorch와 JAX 양쪽 코드를 유지하던 회사들이 통합할 수 있게 됐다"며 "다만 단일 의존도가 높아지는 만큼, Meta의 PyTorch 거버넌스 투명성이 더 중요해진다"고 지적했다.

댓글 0