"""
 flips.py

 See https://cs.marlboro.college/cours/spring2020/data/notes/feb18

 The coin flip is https://cs.marlboro.college/tools/coins/one.cgi .

 Is this coin flip fair? (It gives a random 't' or 'h' with each page load.)
 Make and discuss an explicit hypothesis test to decide,
 in two cases: with 10 coin flips, and with 5000 coin flips.

    $ python3 flips.py 
    --- testing H0=fair_coin with n = 10 ---
    Simulating the null hypothesis ...
      decision cutoffs: 1.6732065515668273, 8.288793448433172
    flipping the coin ...
    100%|███████████████████████████████████████████████████| 10/10 [00:00<00:00, 15.93it/s]
    Number of flips was 5
    Fail to reject the null hypotheis : cannot say that the coin is not fair

    --- testing H0=fair_coin with n = 5000 ---
    Simulating the null hypothesis ...
      decision cutoffs: 2430.806163623822, 2571.5858363761777
    flipping the coin ...
    100%|███████████████████████████████████████████████████| 5000/5000 [04:55<00:00, 16.90it/s]
    Number of flips was 2378
    Reject the null hypothesis : the coin is not fair

 Jim Mahoney | Feb 2020 | cs.marlboro.college | MIT License
"""
import urllib.request
import tqdm
from math import sqrt
from random import randint

def mean(xs):
    """ Return the arithmetic mean of a sized collection of numbers.
        An empty collection raises ZeroDivisionError. """
    total = sum(xs)
    count = len(xs)
    return total / count

def standard_deviation(xs):
    """ Sample standard deviation of xs: the best estimate of the parent
        population's standard deviation. The mean squared deviation is
        scaled by n/(n-1) (Bessel's correction), which is the same as
        dividing the summed squared deviations by n-1. """
    center = mean(xs)
    count = len(xs)
    squared_deviations = [(value - center) ** 2 for value in xs]
    corrected_variance = mean(squared_deviations) * (count / (count - 1))
    return sqrt(corrected_variance)

def count_heads(n, url='https://cs.marlboro.college/tools/coins/one.cgi'):
    """ Return the number of heads in n flips of the online coin.

        Each flip is one HTTP request to `url`. A response body of 'h'
        (surrounding whitespace ignored) counts as heads; anything else
        (e.g. 't') does not. A tqdm progress bar shows progress. """
    heads = 0
    for _ in tqdm.tqdm(range(n)):
        # Close each response as soon as it is read, rather than relying on
        # garbage collection while making thousands of requests.
        with urllib.request.urlopen(url) as response:
            # strip() so a missing or different line ending (b'h', b'h\r\n')
            # is not silently counted as tails.
            if response.read().strip() == b'h':
                heads += 1
    return heads

def cutoffs_from_simulation(n, n_sims=1000, sigmas=2):
    """ Return (low, high) decision cutoffs for the number of heads in n
        flips of a fair coin (binomial with p = 0.5), estimated by simulation.

        The null hypothesis is simulated n_sims times. The cutoffs are
        mean - sigmas * sigma and mean + sigmas * sigma, where mean and
        sigma are taken from the simulated head counts. With the default
        sigmas=2 this is a two-sided test at roughly the 95% confidence level
        (significance about 0.05).

        Raises ValueError if n_sims < 2, since a sample standard deviation
        needs at least two simulated values. """
    if n_sims < 2:
        raise ValueError("n_sims must be at least 2, got {}".format(n_sims))
    # Each simulated experiment: n fair flips, counting 1 for heads.
    sim_heads = [sum(randint(0, 1) for _ in range(n)) for _ in range(n_sims)]
    mu = mean(sim_heads)
    sigma = standard_deviation(sim_heads)
    return (mu - sigmas * sigma, mu + sigmas * sigma)

def main(sample_sizes=(10, 5000)):
    """ Run the fair-coin hypothesis test once for each number of flips in
        sample_sizes. Each run simulates the null hypothesis to get
        decision cutoffs, flips the online coin n times, and prints whether
        the observed number of heads leads to rejecting H0. """
    for n in sample_sizes:
        print("--- testing H0=fair_coin with n = {} ---".format(n))
        print("Simulating the null hypothesis ...")
        (low, high) = cutoffs_from_simulation(n)
        print("  decision cutoffs: {}, {}".format(low, high))
        print("flipping the coin ...")
        heads = count_heads(n)
        # count_heads returns the number of heads, not the number of flips.
        print("Number of heads was {}".format(heads))
        if low <= heads <= high:
            print("Fail to reject the null hypothesis : cannot say that the coin is not fair")
        else:
            print("Reject the null hypothesis : the coin is not fair")
        print()

# Only run the (slow, network-bound) experiment when executed as a script,
# not when this module is imported for its helper functions.
if __name__ == "__main__":
    main()