본문 바로가기
오픈소스2026년 5월 9일4분 읽기

PyTorch 3.0 정식 발표 — JAX 핵심 기능 흡수, "1년 안 JAX 의미 없어진다" 평가

YS
김영삼
조회 1433
PyTorch 3.0 정식 발표 — JAX 핵심 기능 흡수, "1년 안 JAX 의미 없어진다" 평가

핵심 요약

Meta가 주도하는 ML 프레임워크 PyTorch 3.0이 5월 9일 정식 발표됐다. 가장 큰 변화는 Google JAX의 핵심 함수형 변환(vmap, pmap, jit, grad)을 PyTorch 네이티브로 도입한 것. 업계는 "PyTorch가 사실상 ML 프레임워크 시장을 통일했다"고 평가한다.

  • 버전: PyTorch 3.0.0
  • 핵심 추가: torch.func 모듈 정식화 (이전 functorch 후속)
  • 성능: torch.compile 안정화, NVIDIA H200/B100에 최적화
  • 라이선스: BSD 그대로

JAX 기능 흡수 — 무엇이 바뀌나

JAX 함수PyTorch 3.0 대응의미
jax.vmaptorch.vmap배치 차원 자동 처리
jax.pmaptorch.pmap멀티 디바이스 자동 분산
jax.jittorch.compile (안정화)JIT 컴파일
jax.gradtorch.func.grad함수형 미분
jaxprFX Graph (확장)중간 표현

왜 지금

2024~2025년 Google이 Gemini 학습 인프라를 JAX 기반으로 표준화하면서, JAX 사용자가 빠르게 늘었다(특히 학계). PyTorch 측은 "함수형 변환이 PyTorch의 약점"이라는 비판을 인정하고 2년간 통합 작업을 추진했다고 공식 블로그에서 밝혔다.

코드 예시 — 익숙한 PyTorch 문법으로 JAX의 강력함

import torch
from torch.func import vmap, grad, jit

def loss(params, x, y):
    pred = params['w'] @ x + params['b']
    return ((pred - y) ** 2).mean()

# 단일 샘플 미분
grad_fn = grad(loss)
# 배치로 자동 확장 (JAX의 vmap 그대로)
batched_grad = vmap(grad_fn, in_dims=(None, 0, 0))
# JIT 컴파일
compiled = jit(batched_grad)

실측 — Llama 3 학습 32 GPU 비교

지표PyTorch 2.5PyTorch 3.0JAX 0.6
throughput (samples/sec)1,8202,1402,180
메모리 (per GPU)74GB62GB59GB
코드 변경 vs PyTorch 2거의 없음완전 재작성

JAX의 미래

한 Google Brain 출신 연구자는 "JAX의 유일한 사용 이유였던 함수형 변환을 PyTorch가 가져갔다. 1년 안에 JAX는 의미를 잃을 것"이라고 평가했다. Google 내부에서도 Gemini 3 이후의 학습 코드가 PyTorch로 옮겨가는 분위기라는 소문이다.

한국 영향

  • 네이버: HyperCLOVA X 학습 코드 PyTorch 기반 — 3.0 적용 시 학습 비용 12~18% 절감 예상
  • 카카오 Brain: 일부 JAX 사용 팀이 PyTorch로 통합 검토 중
  • 대학 연구실: JAX에 투자한 박사과정 학생 일부 당황 — "내 논문 코드를 다시 짜야 하나"

커뮤니티 반응

Hacker News에서 1,400개 이상 댓글이 달렸고, "PyTorch wins" vs "JAX는 단순히 ML 도구가 아닌 함수형 사고방식" 논쟁이 격렬했다. PyTorch의 한 핵심 메인테이너는 트위터에 "우리는 JAX를 죽이려는 게 아니다. 그 좋은 아이디어를 더 많은 사람이 쓸 수 있게 만든 것일 뿐"이라고 적었다.

전문가 코멘트

김기현 카이스트 AI대학원 교수는 "프레임워크 단일화는 연구자에겐 좋은 소식. PyTorch와 JAX 양쪽 코드를 유지하던 회사들이 통합할 수 있게 됐다"며 "다만 단일 의존도가 높아지는 만큼, Meta의 PyTorch 거버넌스 투명성이 더 중요해진다"고 지적했다.

댓글 0

아직 댓글이 없습니다.
Ctrl+Enter로 등록