Este guia explica como usar o back-end do JAX no Meridian.
Introdução ao back-end do JAX
Por padrão, o Meridian usa o TensorFlow para as principais operações numéricas e a amostragem de Monte Carlo via cadeias de Markov (MCMC, na sigla em inglês) probabilística, oferecendo uma base robusta e totalmente testada para todas as tarefas de modelagem.
Para projetos que se beneficiam de desempenho e eficiência de memória aprimorados, o Meridian oferece o back-end do JAX. O JAX incentiva um estilo de programação funcional e usa a compilação de álgebra linear acelerada (XLA, na sigla em inglês) para oferecer otimizações avançadas de desempenho.
Tutorial: para ver o JAX em ação, consulte o notebook Introdução ao JAX.
Como ativar o JAX
As bibliotecas matemáticas principais são carregadas na inicialização. Por isso, você precisa instruir o Meridian a usar o JAX antes de importar qualquer módulo do Meridian.
Para ativar o JAX, defina a variável de ambiente MERIDIAN_BACKEND como 'jax'.
Defina essa variável de ambiente no script antes de executar qualquer instrução import meridian:
import os
# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
Ativar a precisão de 64 bits
Para modelos em que a convergência é difícil de alcançar, usar precisão de 64 bits com o back-end do JAX pode melhorar a estabilidade numérica. Apesar disso, essa precisão aumenta o uso de memória e diminui os tempos de computação. Portanto, a precisão de 32 bits continua sendo o padrão para a maioria dos casos de uso. Para ativá-la, defina a variável de ambiente MERIDIAN_ENABLE_JAX_X64 como 'True' antes de importar o Meridian.
import os
# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'
# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
Se uma string inválida for fornecida à variável de ambiente MERIDIAN_BACKEND, o Meridian vai emitir um aviso de tempo de execução e voltar à execução padrão do TensorFlow. Se um valor diferente de "True" ou "1" for fornecido à variável de ambiente MERIDIAN_ENABLE_JAX_X64, a precisão de 64 bits não será ativada, e o Meridian vai usar a precisão de 32 bits por padrão.
Diferenças de API ao usar JAX em vez de TensorFlow
Ao fazer a transição do TensorFlow para o back-end JAX, há diferenças importantes na API que você precisa ajustar no seu código:
Distribuições a priori
Se você estiver definindo distribuições a priori personalizadas com PriorDistribution, utilize tfp.substrates.jax.distributions em vez de tfp.distributions. Exemplo:
TensorFlow
import tensorflow_probability as tfp
from meridian.model import prior_distribution
prior = prior_distribution.PriorDistribution(
roi_m=tfp.distributions.LogNormal(0.2, 0.9)
)
JAX
import tensorflow_probability as tfp
tfp_jax = tfp.substrates.jax
from meridian.model import prior_distribution
prior = prior_distribution.PriorDistribution(
roi_m=tfp_jax.distributions.LogNormal(0.2, 0.9)
)
Requisito de semente explícita
Ao usar o back-end JAX, é necessário ter uma semente explícita para funções estocásticas (por exemplo, em sample_posterior()). Enquanto o TensorFlow usa um gerador global de números aleatórios que escolhe automaticamente uma semente, o JAX torna essa semente explícita. Não encontramos diferenças estatisticamente significativas nas estimativas de ROI ou nas mudanças de orçamento em diferentes sementes..
# Explicitly set a seed for MCMC sampling when using the JAX backend
mmm.sample_posterior(
n_chains=2,
n_adapt=1000,
n_burnin=500,
n_keep=1000,
seed=0,
)
Para mais detalhes sobre números aleatórios e sementes do JAX, consulte a documentação de números pseudoaleatórios do JAX.
Diferenças numéricas e reprodutibilidade
Como o TensorFlow e o JAX compilam os grafos computacionais de maneira diferente, você pode observar pequenas diferenças numéricas nas estimativas a posteriori ao mudar para o JAX usando os mesmos dados e sementes aleatórias.
Embora as distribuições a posteriori não sejam idênticas em todos os back-ends, as diferenças geralmente são pequenas e não estatisticamente significativas para métricas de negócios como ROI e alocação de orçamento. Isso garante que a troca para o back-end do JAX mantenha a integridade dos insights do modelo.
Considerações sobre desempenho
Testes internos descobriram que o JAX impulsionou as execuções iniciais do modelo, reduzindo o tempo de execução médio em cerca de 40% e o uso da memória em cerca de 70%, em comparação com o TensorFlow ao usar GPUs. O JAX também simplificou as iterações do modelo, permitindo tempos de execução 2 vezes mais rápidos, uso de memória 4 vezes menor e fluxos de trabalho ininterruptos ao eliminar a necessidade de reinicializações do kernel.
Devido ao aumento da eficiência da memória, você tem mais espaço para ajustar parâmetros computacionalmente intensivos. Por exemplo, em Meridian.sample_posterior(), você pode aumentar o argumento unrolled_leapfrog_steps (por exemplo, de 1 para 5). Isso pode acelerar a convergência aumentando o comprimento da trajetória do No U Turn Sampler (NUTS) sem exceder os limites de memória do hardware. Também é possível aumentar o parâmetro n_adapt para ajudar ainda mais na convergência durante a fase de adaptação.