Skeptical, Weakly-Informative, and Strongly-Informative Priors

Labs

Author: Jeff Jacobs
Published: June 28, 2025

import pandas as pd
import numpy as np
rng = np.random.default_rng(seed=5650)
import matplotlib.pyplot as plt
import seaborn as sns
# Colorblind-friendly (Okabe-Ito) palette
cb_palette = ['#e69f00','#56b4e9','#009e73']
sns.set_palette(cb_palette)
import patchworklib as pw
import pymc as pm
import arviz as az

[Part 0] Our Simulated Coin Flips

# Reuse the seeded generator from above (same seed, so the flips are reproducible)
all_flip_results = list(rng.binomial(n=1, p=0.5, size=250))
np.mean(all_flip_results)
0.5

Or, if preferred, truly random flips (from random.org):

true_random_flips = [
    1,1,0,1,1,1,0,0,1,0,0,0,0,0,1,1,1,0,1,1,1,1,0,1,0,0,0,0,1,1,0,1,1,1,0,0,1,0,
    1,1,1,0,0,0,1,0,0,0,0,1,1,0,1,1,1,0,1,0,1,0,0,1,0,1,1,0,1,0,1,0,1,0,1,1,0,0,
    1,0,0,0,0,1,0,0,0,1,0,1,1,0,0,1,0,0,1,0,1,0,1,1
]
np.mean(true_random_flips)
0.49
one_flip_result = all_flip_results[:1]
one_flip_result
[0]
two_flips_result = all_flip_results[:2]
two_flips_result
[0, 1]
five_flips_result = all_flip_results[:5]
five_flips_result
[0, 1, 0, 0, 1]
ten_flips_result = all_flip_results[:10]
ten_flips_result
[0, 1, 0, 0, 1, 0, 1, 0, 1, 0]

[Part 1] Informative (Beta) Prior Model

with pm.Model() as inf_model:
    result_obs = pm.Data('result_obs', one_flip_result)
    p_heads = pm.Beta("p_heads", alpha=2, beta=2)
    result = pm.Bernoulli("result", p=p_heads, observed=result_obs)
pm.model_to_graphviz(inf_model)
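
Since the Beta distribution is conjugate to the Bernoulli likelihood, we also know the posterior in closed form: a \(\text{Beta}(\alpha, \beta)\) prior combined with \(h\) observed heads in \(N\) flips yields a \(\text{Beta}(\alpha + h, \beta + N - h)\) posterior. This gives us an exact answer to check the MCMC approximations against as we go.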

def draw_prior_sample(model, return_idata=False):
    """Draw 5000 samples from the model's prior and return them as a tidy DataFrame."""
    with model:
        prior_idata = pm.sample_prior_predictive(draws=5000, random_seed=5650)
    # Flatten the InferenceData's prior group into one row per draw
    prior_df = prior_idata.prior.to_dataframe().reset_index().drop(columns='chain')
    if return_idata:
        return prior_idata, prior_df
    return prior_df
inf_n0_df = draw_prior_sample(inf_model)
Sampling: [p_heads, result]
def gen_dist_plot(dist_df, plot_title):
    """Histogram of the sampled p_heads values, returned as a patchworklib Brick."""
    ax = pw.Brick(figsize=(3.5, 2.25))
    sns.histplot(x="p_heads", data=dist_df, ax=ax, bins=25)
    ax.set_title(plot_title)
    return ax
inf_n0_plot = gen_dist_plot(inf_n0_df, "Beta(2, 2) Prior on p")
inf_n0_plot.savefig()

def draw_post_sample(model, num_draws=5000):
    """Run NUTS on the model and return the posterior draws as a tidy DataFrame."""
    with model:
        post_idata = pm.sample(draws=num_draws, random_seed=5650)
    post_df = post_idata.posterior.to_dataframe().reset_index().drop(columns='chain')
    return post_df
inf_n1_df = draw_post_sample(inf_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
inf_n1_plot = gen_dist_plot(inf_n1_df, "Posterior (N = 1)")
inf_n1_plot.savefig()
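
As a quick sanity check of the conjugate formula above (a sketch using scipy, which the lab doesn't otherwise import): one observed tail should turn the Beta(2, 2) prior into a Beta(2, 3) posterior, whose mean is \(2/5 = 0.4\).

from scipy import stats
# Conjugate update for one_flip_result = [0]: 0 heads, 1 tail
print(stats.beta(a=2, b=3).mean())    # exact posterior mean: 0.4
print(inf_n1_df['p_heads'].mean())    # MCMC estimate; should be close to 0.4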

Observe \(N = 2\) Flips

with inf_model:
    pm.set_data({'result_obs': two_flips_result})
inf_n2_df = draw_post_sample(inf_model)    
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
inf_n2_plot = gen_dist_plot(inf_n2_df, "Posterior (N = 2)")
inf_n2_plot.savefig()

def plot_n_dists(n_df):
    """Overlay KDEs of p_heads for each sample size n in the combined DataFrame."""
    ax = pw.Brick(figsize=(5, 3.5))
    sns.kdeplot(
        x="p_heads", hue="n", fill=True, ax=ax, data=n_df,
        common_norm=False  # normalize each distribution separately
    )
    display(ax.savefig())
inf_n0_df['n'] = 0
inf_n1_df['n'] = 1
inf_n2_df['n'] = 2
inf_3_df = pd.concat([inf_n0_df, inf_n1_df, inf_n2_df])
plot_n_dists(inf_3_df)
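
(Closed-form check: after the two flips [0, 1], the posterior is \(\text{Beta}(3, 3)\), which is symmetric around \(p = 0.5\) again, just more concentrated than the \(\text{Beta}(2, 2)\) prior.)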

Observe \(N = 5\) Flips

with inf_model:
    pm.set_data({'result_obs': five_flips_result})
inf_n5_df = draw_post_sample(inf_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
inf_n5_plot = gen_dist_plot(inf_n5_df, "Posterior (N = 5)")
inf_n5_plot.savefig()

inf_n5_df['n'] = 5
inf_4_df = pd.concat([inf_3_df, inf_n5_df])
plot_n_dists(inf_4_df)

Observe \(N = 10\) Flips

with inf_model:
    pm.set_data({'result_obs': ten_flips_result})
inf_n10_df = draw_post_sample(inf_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
inf_n10_df['n'] = 10
inf_5_df = pd.concat([inf_4_df, inf_n10_df])
plot_n_dists(inf_5_df)


[Part 2] Flat (Uniform) Prior

with pm.Model() as unif_model:
    result_obs = pm.Data('result_obs', one_flip_result)
    p_heads = pm.Beta("p_heads", alpha=1, beta=1)
    result = pm.Bernoulli("result", p=p_heads, observed=result_obs)
pm.model_to_graphviz(unif_model)
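
A \(\text{Beta}(1, 1)\) distribution is exactly the \(\text{Uniform}(0, 1)\) distribution: its density is the constant 1 over the whole unit interval. A quick check (a sketch using scipy):

from scipy import stats
# Beta(1, 1) has constant density 1 on [0, 1], i.e., it is Uniform(0, 1)
print(stats.beta(1, 1).pdf([0.1, 0.5, 0.9]))  # [1. 1. 1.]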

unif_n0_df = draw_prior_sample(unif_model)
Sampling: [p_heads, result]
      draw   p_heads
0        0  0.970164
1        1  0.085831
2        2  0.101310
3        3  0.315523
4        4  0.285908
...    ...       ...
4995  4995  0.254631
4996  4996  0.166792
4997  4997  0.772046
4998  4998  0.296384
4999  4999  0.331467

[5000 rows x 2 columns]

unif_n0_plot = gen_dist_plot(unif_n0_df, "Uniform Prior")
unif_n0_plot.savefig()

Posterior After \(N = 1\) Observed Flip

unif_n1_df = draw_post_sample(unif_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
unif_n1_plot = gen_dist_plot(unif_n1_df, f"Posterior After Observing X = {one_flip_result}")
unif_n1_plot.savefig()

Posterior After \(N = 2\) Flips

with unif_model:
    pm.set_data({'result_obs': two_flips_result})
unif_n2_df = draw_post_sample(unif_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
unif_n2_plot = gen_dist_plot(unif_n2_df, f"Posterior After Observing {two_flips_result}")
unif_n2_plot.savefig()

unif_n0_df['n'] = 0
unif_n1_df['n'] = 1
unif_n2_df['n'] = 2
unif_3_df = pd.concat([unif_n0_df, unif_n1_df, unif_n2_df])
plot_n_dists(unif_3_df)

Observe \(N = 5\) Flips

with unif_model:
    pm.set_data({'result_obs': five_flips_result})
unif_n5_df = draw_post_sample(unif_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
unif_n5_df['n'] = 5
unif_4_df = pd.concat([unif_3_df, unif_n5_df])
plot_n_dists(unif_4_df)

Observe \(N = 10\) Flips

with unif_model:
    pm.set_data({'result_obs': ten_flips_result})
unif_n10_df = draw_post_sample(unif_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
unif_n10_df['n'] = 10
unif_5_df = pd.concat([unif_4_df, unif_n10_df])
plot_n_dists(unif_5_df)

[Part 3] Skeptical (Jeffreys) Prior

This prior is itself derived from an approach to Bayesian statistics called “Objective Bayes”, within which the Jeffreys Prior for the Bernoulli parameter \(p\) has a special status.

For our purposes, however, we can just view it as a "skeptical" prior: it encodes an assumption that the coin is very biased, i.e., that before seeing any actual coin flips we consider values of \(p\) near 0 and 1 more likely than the values in between. This means that, relative to the Beta(2, 2) and Uniform priors, someone holding this prior would require a very "even" mixture of heads and tails to "cancel out" their pre-existing belief that the coin is biased!
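
To see this shape concretely, we can evaluate all three prior densities near the edges and at the center of the unit interval (a sketch using scipy, not part of the original lab):

from scipy import stats
for a, b in [(2, 2), (1, 1), (0.5, 0.5)]:
    dens = stats.beta(a, b).pdf([0.01, 0.5, 0.99])
    print(f"Beta({a}, {b}):", np.round(dens, 3))
# Beta(0.5, 0.5) piles density near 0 and 1 rather than at 0.5,
# while Beta(2, 2) does the opposite and Beta(1, 1) stays constant.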

with pm.Model() as flat_model:
    result_obs = pm.Data('result_obs', one_flip_result)
    p_heads = pm.Beta("p_heads", alpha=0.5, beta=0.5)  # Jeffreys prior
    result = pm.Bernoulli("result", p=p_heads, observed=result_obs)
pm.model_to_graphviz(flat_model)

flat_n0_df = draw_prior_sample(flat_model)
Sampling: [p_heads, result]
flat_n0_plot = gen_dist_plot(flat_n0_df, "Jeffreys Beta(0.5, 0.5) Prior")
flat_n0_plot.savefig()

Observe \(N = 1\) Flip

flat_n1_df = draw_post_sample(flat_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
flat_n1_plot = gen_dist_plot(flat_n1_df, f"Posterior After Observing {one_flip_result}")
flat_n1_plot.savefig()

Observe \(N = 2\) Flips

with flat_model:
    pm.set_data({'result_obs': two_flips_result})
pm.model_to_graphviz(flat_model)

flat_n2_df = draw_post_sample(flat_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
flat_n2_plot = gen_dist_plot(flat_n2_df, "Posterior (N = 2)")
flat_n2_plot.savefig()

flat_n0_df['n'] = 0
flat_n1_df['n'] = 1
flat_n2_df['n'] = 2
flat_3_df = pd.concat([flat_n0_df, flat_n1_df, flat_n2_df])
plot_n_dists(flat_3_df)

Observe \(N = 5\) Flips

with flat_model:
    pm.set_data({'result_obs': five_flips_result})
flat_n5_df = draw_post_sample(flat_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
flat_n5_df['n'] = 5
flat_4_df = pd.concat([flat_3_df, flat_n5_df])
plot_n_dists(flat_4_df)

Observe \(N = 10\) Flips

with flat_model:
    pm.set_data({'result_obs': ten_flips_result})
flat_n10_df = draw_post_sample(flat_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [p_heads]

Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 3 seconds.
flat_n10_df['n'] = 10
flat_5_df = pd.concat([flat_4_df, flat_n10_df])
plot_n_dists(flat_5_df)

[Part 4] Which One Learned Most Efficiently?

After \(N = 2\)?

def plot_dist_comparison(combined_df):
    """Overlay KDEs of p_heads across the three priors, at a fixed n."""
    ax = pw.Brick(figsize=(5, 3.5))
    sns.kdeplot(
        x="p_heads", hue="prior", fill=True, ax=ax, data=combined_df,
        common_norm=False  # normalize each distribution separately
    )
    display(ax.savefig())
inf_n2_df['prior'] = 'Beta(2, 2)'
unif_n2_df['prior'] = 'Beta(1, 1)'
flat_n2_df['prior'] = 'Beta(0.5, 0.5)'
all_n2_df = pd.concat([inf_n2_df, unif_n2_df, flat_n2_df])
plot_dist_comparison(all_n2_df)

After \(N = 5\)?

inf_n5_df['prior'] = 'Beta(2, 2)'
unif_n5_df['prior'] = 'Beta(1, 1)'
flat_n5_df['prior'] = 'Beta(0.5, 0.5)'
all_n5_df = pd.concat([inf_n5_df, unif_n5_df, flat_n5_df])
plot_dist_comparison(all_n5_df)

After \(N = 10\)?

inf_n10_df['prior'] = 'Beta(2, 2)'
unif_n10_df['prior'] = 'Beta(1, 1)'
flat_n10_df['prior'] = 'Beta(0.5, 0.5)'
all_n10_df = pd.concat([inf_n10_df, unif_n10_df, flat_n10_df])
plot_dist_comparison(all_n10_df)

with inf_model:
    print(pm.find_MAP())

{'p_heads_logodds__': array(-0.33647223), 'p_heads': array(0.41666667)}
with unif_model:
    print(pm.find_MAP())

{'p_heads_logodds__': array(-0.4054651), 'p_heads': array(0.4)}
with flat_model:
    print(pm.find_MAP())

{'p_heads_logodds__': array(-0.47956846), 'p_heads': array(0.38235403)}
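
These MAP estimates line up with the conjugate closed form: with a \(\text{Beta}(\alpha, \beta)\) prior and \(h\) heads in \(N\) flips, the posterior mode is \((\alpha + h - 1)/(\alpha + \beta + N - 2)\). A quick cross-check (a sketch; the small gap in the Jeffreys case looks like optimizer tolerance rather than a real difference):

# Closed-form posterior modes for h = 4 heads in N = 10 flips
h, N = sum(ten_flips_result), len(ten_flips_result)
for a, b in [(2, 2), (1, 1), (0.5, 0.5)]:
    mode = (a + h - 1) / (a + b + N - 2)
    print(f"Beta({a}, {b}) prior -> posterior mode {mode:.5f}")
# Beta(2, 2) -> 0.41667, Beta(1, 1) -> 0.40000, Beta(0.5, 0.5) -> 0.38889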