In [None]:
import math
import numpy as np
import torch

def load_dataset(path="dataset"):
    """Read a tab-separated ``label<TAB>text`` file into ``[tokens, label]`` pairs.

    Each line must contain exactly one tab separating the label from the
    message text. The text has the characters ``. , ? ! ` '`` replaced by
    spaces, is lower-cased, and is split on runs of whitespace.

    Args:
        path: File to read. Defaults to ``"dataset"`` in the working directory.

    Returns:
        A list of ``[tokens, label]`` lists, where ``tokens`` is a list of str
        and ``label`` is the raw label string taken from the file.

    Lines that do not split into exactly two tab-separated fields are skipped.
    """
    punctuation = ".,?!`'"
    # One translation table maps every separator character to a space.
    separators = str.maketrans(punctuation, " " * len(punctuation))
    dataset = []
    with open(path, "r") as f:
        for line in f:
            try:
                label, email = line.split("\t")
            except ValueError:
                # Malformed line (no tab, or more than one): skip it, but do not
                # hide unrelated errors (KeyboardInterrupt, MemoryError, ...).
                continue
            # Bare split() collapses whitespace runs and strips "\n"/"\r", so
            # double spaces no longer yield "" pseudo-words in the vocabulary.
            tokens = email.translate(separators).lower().split()
            dataset.append([tokens, label])
    return dataset
            
# Parse the tab-separated "dataset" file into a list of [tokens, label] pairs.
dataset = load_dataset()

In [None]:
def build_vocabulary(dataset):
    """Assign a dense integer id to every distinct token in ``dataset``.

    Ids start at 0 and are handed out in order of first appearance, scanning
    examples in order and each example's tokens left to right.

    Args:
        dataset: Iterable of ``(tokens, label)`` pairs; only the tokens are used.

    Returns:
        dict mapping each distinct token to its column index.
    """
    # dict.fromkeys keeps first-seen order and silently drops repeats.
    unique_words = dict.fromkeys(
        word for tokens, _label in dataset for word in tokens
    )
    return {word: idx for idx, word in enumerate(unique_words)}

# Map every distinct token to a feature-column index.
# NOTE(review): this is built from the whole dataset, before the train/test
# split, so it also contains tokens that only occur in held-out rows.
vocabulary = build_vocabulary(dataset)

In [None]:
def transform_dataset(dataset, vocabulary):
    """Encode tokenised examples as binary bag-of-words features and 0/1 labels.

    Args:
        dataset: Sequence of ``(tokens, label)`` pairs.
        vocabulary: dict mapping token -> column index (e.g. from
            ``build_vocabulary``).

    Returns:
        ``(X, Y)``. ``X`` is a float ndarray of shape
        ``(len(dataset), len(vocabulary))`` with ``X[i, j] == 1`` iff the token
        with index ``j`` occurs in example ``i``. ``Y`` is a float ndarray of
        shape ``(len(dataset),)`` holding 1 for label ``'spam'`` and 0 otherwise.

    Tokens missing from ``vocabulary`` are ignored instead of raising
    ``KeyError``, so a vocabulary built on a training split can be applied to
    held-out data.
    """
    X = np.zeros((len(dataset), len(vocabulary)))
    Y = np.zeros(len(dataset))
    for i, (tokens, label) in enumerate(dataset):
        for word in tokens:
            column = vocabulary.get(word)
            if column is not None:  # skip out-of-vocabulary tokens
                X[i, column] = 1
        if label == 'spam':
            Y[i] = 1
    return X, Y

# Number of leading examples used for training; the rest are held out for testing.
# NOTE(review): the split follows file order and is not shuffled. If the source
# file is sorted or grouped (by label, date, ...), the splits will not be
# representative; consider shuffling with a fixed seed first.
N_TRAIN = 4000

X, y = transform_dataset(dataset, vocabulary)
# Guarantee both splits are non-empty; an empty test split would make the
# test-accuracy computation divide by zero.
assert 0 < N_TRAIN < len(y), (
    f"need more than {N_TRAIN} examples to leave a test split, got {len(y)}"
)
trX, trY = X[:N_TRAIN], y[:N_TRAIN]
teX, teY = X[N_TRAIN:], y[N_TRAIN:]

In [None]:
# Convert the numpy train/test splits to float32 tensors for PyTorch.
trX, trY, teX, teY = (
    torch.tensor(split, dtype=torch.float32) for split in (trX, trY, teX, teY)
)

In [None]:
# Seed PyTorch's RNG so weight initialisation, and therefore every reported
# accuracy, is reproducible on Restart & Run All.
SEED = 0
HIDDEN_UNITS = 50
LEARNING_RATE = 1.0
torch.manual_seed(SEED)

# Two-layer MLP: bag-of-words features -> HIDDEN_UNITS sigmoid units -> one
# sigmoid output (predicted probability of the 'spam' label).
# Module names are unchanged so state_dict keys stay stable.
model = torch.nn.Sequential()
model.add_module("linear", torch.nn.Linear(len(vocabulary), HIDDEN_UNITS))
model.add_module("sigmoid", torch.nn.Sigmoid())
model.add_module("linear2", torch.nn.Linear(HIDDEN_UNITS, 1))
model.add_module("sigmoid2", torch.nn.Sigmoid())

# BCELoss expects probabilities, which the final Sigmoid produces.
# NOTE(review): dropping "sigmoid2" and using BCEWithLogitsLoss is more
# numerically stable, but changes what model(...) returns (logits).
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [None]:
N_EPOCHS = 500
LOG_EVERY = 100

# Full-batch gradient descent: each epoch is one forward/backward pass over all
# of trX. The logged accuracy is from the forward pass *before* that epoch's step.
train_labels = trY.bool()  # 0./1. targets as bool, hoisted out of the loop
for epoch in range(N_EPOCHS):
    output = model(trX).view(-1)
    loss = criterion(output, trY)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Log periodically and always on the last epoch, so the final training
    # accuracy is reported (the original only printed epochs 0..400).
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        acc = ((output > .5) == train_labels).sum().item() / float(len(trY))
        print(epoch, round(loss.item(), 4), acc)

In [None]:
# Evaluate on the held-out split without building an autograd graph;
# an output above 0.5 is counted as a 'spam' prediction.
with torch.no_grad():
    output = model(teX).view(-1)
    predicted_spam = output > .5
    acc = (teY.byte() == predicted_spam).sum().item() / float(len(teY))
    print(acc)