"""
 analyze.py

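 Compute word probabilities P(w) and word-pair conditional probabilities
 P(next word | previous word) for a text file, then numerically check
 marginalization and Bayes' theorem. A sample run: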

   $ python3 analyze.py 
   filename = 'short1.txt'
   n words =  339
   first few words:  ['for', 'two', 'days', 'and']
   first few pairs:  [('for', 'two'), ('two', 'days'), ('days', 'and'), ('and', 'nights')]

   n unique words =  176
   n unique pairs =  308

   sum of probabilities is  1.0000000000000013
      word            probability  
      ------------    ------------ 
      and             0.0708
      the             0.0649
      he              0.0354
      was             0.0295
      his             0.0236
      a               0.0236
      him             0.0236
      that            0.0236
      of              0.0206
      for             0.0206

   look at some conditional probabilities
    ---- given first word "and"
     sum of given1_prob2[and][w] is  0.9999999999999996
     next likely nights(0.1667) he(0.0833) his(0.0417) 
    ---- given first word "the"
     sum of given1_prob2[the][w] is  0.9999999999999996
     next likely man(0.0909) bars(0.0909) express(0.0909) 
    ---- given first word "he"
     sum of given1_prob2[he][w] is  1.0
     next likely was(0.2500) knew(0.0833) accumulated(0.0833) 

   check that P(x) = sum over y of P(x|y) * P(y)
     probability P("he") is 0.03540 
     sum of P("he"|y)*P(y) is 0.03540

   check Bayes theorem P(y|x) = P(x|y)*P(y)/P(x) 
   or P(x|y)*P(y) = P(y|x)*P(x)
   Choose  x="he" y="was"
    P(x) =  0.035398230088495575
    P(y) =  0.029498525073746312
    P(y|x) =  0.25 = (count (he, was)) / (count (he, _))
    P(x|y) =  0.3 = (count (he, was)) / (count (_, was))
    P(x) * P(y|x) =  0.008849557522123894
    P(y) * P(x|y) =  0.008849557522123894
    P(x & y) =  0.008875739644970414

   So there. Mnnnh.

 Jim Mahoney | Feb 2020 | cs.marlboro.college | MIT License
"""

def get_words(filename):
    """ Return a list of the lowercase words in the file.

        Words are the whitespace-separated tokens of the text, lowercased,
        with the characters , . " ' ; : ! ( ) * $ deleted.  Tokens made up
        entirely of those characters (e.g. a lone quote mark) are skipped
        rather than returned as empty strings.
    """
    # adapted from my Nov 2012 count_words.py which analyzed moby dick.
    ignore = ( ',', '.', '"', "'", ';', ':', '!', '(', ')', '*', '$' )
    # One translation table deletes every ignored character in a single pass.
    strip_punctuation = str.maketrans('', '', ''.join(ignore))
    result = []
    # 'with' guarantees the file is closed, even if reading raises.
    with open(filename, 'r') as infile:
        for line in infile:
            for word in line.split():
                word = word.lower().translate(strip_punctuation)
                if word:                # skip tokens that were all punctuation
                    result.append(word)
    return result

def conditionals(pairs, *, dense=True):
    """ Given pairs=[(1st_word_1, 2nd_word_1), ...
        return probability of 2nd_word given 1st_word
        as given1_prob2[one][two] = p(two|one)
                                  = count(one, two) / count(one, _).

        pairs must be re-iterable (a list or tuple); it is traversed
        more than once.

        dense=True (the default): every inner dict has an entry for every
        word that ever appears second, with 0.0 for pairs never seen.
        That table has (distinct firsts) x (distinct seconds) entries.

        dense=False: inner dicts hold only the observed pairs, so memory
        grows with the number of distinct pairs instead; look up with
        given1_prob2[one].get(two, 0.0).
    """
    from collections import Counter
    # For any one, sum(given1_prob2[one][two] for two in given1_prob2[one])
    # is 1, up to floating-point rounding.
    pair_count = Counter(pairs)                          # (one, two) -> count
    first_count = Counter(one for (one, two) in pairs)   # one -> count(one, _)
    if not dense:
        given1_prob2 = {}
        for (one, two), count in pair_count.items():
            given1_prob2.setdefault(one, {})[two] = count / first_count[one]
        return given1_prob2
    # Dense table: every first word gets an entry for every observed second.
    seconds = tuple(set(two for (one, two) in pairs))
    firsts = tuple(set(one for (one, two) in pairs))
    return {first: {second: pair_count[(first, second)] / first_count[first]
                    for second in seconds}
            for first in firsts}

def probabilities(words):
    """ Return dict p[word] = (occurrences of word) / len(words),
        the relative frequency of each distinct word in the list. """
    tally = {}
    for w in words:
        tally[w] = tally.get(w, 0) + 1
    # Normalize the raw counts so the values sum to (about) 1.
    return {w: n / len(words) for w, n in tally.items()}

def main():
    """ Run the analysis on short1.txt and print the results.

        Prints word and word-pair counts, the (up to) ten most probable
        words, the (up to) three most likely successors of a few example
        words, and then numerically checks
          * marginalization:  P(x) = sum over y of P(x|y) * P(y)
          * Bayes' theorem:   P(x|y) * P(y) = P(y|x) * P(x)
        The example words ('and', 'the', 'he', 'was') were chosen for
        short1.txt; another input file may not contain them.
    """
    filename = 'short1.txt'  # excerpt from "the call .."
    # filename = 'call_of_the_wild.txt' ... dies with 'killed'; too much memory
    # (conditionals() builds a dense n_firsts x n_seconds table by default)

    words = get_words(filename)
    pairs = list(zip(words, words[1:]))   # consecutive (word, next_word)

    print('filename = ', filename)
    print('n words = ', len(words))
    print('first few words: ', words[:4])
    print('first few pairs: ', pairs[:4])
    print()

    print('n unique words = ', len(set(words)))
    print('n unique pairs = ', len(set(pairs)))
    print()

    p = probabilities(words)
    words_by_freq = sorted(p, key=p.get, reverse=True)
    print('sum of probabilities is ', sum(p.values()))
    print('   {:12s}    {:12s} '.format('word', 'probability'))
    print('   {:12s}    {:12s} '.format('-'*12, '-'*12))
    for word in words_by_freq[:10]:   # slice: no IndexError on tiny texts
        print('   {:12s}    {:0.4f}'.format(word, p[word]))
    print()

    print('look at some conditional probabilities')
    # for example: after 'he', 'was' is the most common.
    given1_prob2 = conditionals(pairs)
    for given in ('and', 'the', 'he'):
        print(' ---- given first word "{}"'.format(given))
        next_probs = given1_prob2[given]          # next_probs[w] = P(w | given)
        print('  sum of given1_prob2[{}][w] is '.format(given),
              sum(next_probs.values()))
        likely = sorted(next_probs, key=next_probs.get, reverse=True)
        print('  next likely ', end='')
        for two in likely[:3]:
            print('{}({:.4f}) '.format(two, next_probs[two]), end='')
        print()
    print()

    print('check that P(x) = sum over y of P(x|y) * P(y)')
    # Here y is the word *before* x.  Agreement is only approximate: the
    # last word has no successor and the first word no predecessor, so a
    # word's pair counts can be one less than its word count.
    testword = 'he'
    prob_sum = sum(given1_prob2[one][testword] * p[one] for one in given1_prob2)
    print('  probability P("{}") is {:0.5f} '.format(testword, p[testword]))
    print('  sum of P("{}"|y)*P(y) is {:0.5f}'.format(testword, prob_sum))
    print()

    print('check Bayes theorem P(y|x) = P(x|y)*P(y)/P(x) ')
    print('or P(x|y)*P(y) = P(y|x)*P(x)')
    # First find the conditional probabilities the other way 'round:
    # given2_prob1[two][one] = P(one | two).
    flipped_pairs = tuple((two, one) for (one, two) in pairs)
    given2_prob1 = conditionals(flipped_pairs)
    # Then choose a pair to check
    (x, y) = ('he', 'was')   # most common after 'he' from stuff above
    print('Choose  x="{}" y="{}"'.format(x, y))
    print(' P(x) = ', p[x])  # (count of word 'he') / (count all words)
    print(' P(y) = ', p[y])  # (count of word 'was') / (count all words)
    print(' P(y|x) = ', given1_prob2[x][y], '= (count (he, was)) / (count (he, _))')
    print(' P(x|y) = ', given2_prob1[y][x], '= (count (he, was)) / (count (_, was))')
    print(' P(x) * P(y|x) = ', p[x] * given1_prob2[x][y])
    print(' P(y) * P(x|y) = ', p[y] * given2_prob1[y][x])
    # Both products estimate the explicit probability of that pair,
    # P(x & y) = (count of pairs (he was)) / (count of all pairs).
    # They differ slightly from it because P(x) and P(y) divide by the
    # number of words, which is one more than the number of pairs.
    print(' P(x & y) = ', pairs.count((x, y)) / len(pairs))
    print()

    print('So there. Mnnnh.')
    
# Run only when executed as a script, so importing this module (e.g. to
# reuse get_words or conditionals) does not try to read short1.txt.
if __name__ == '__main__':
    main()