Setup¶

In [ ]:
# Setup - Run only once per Kernel App
%conda install openjdk -y

# install PySpark
%pip install pyspark==3.4.0

# install spark-nlp
%pip install spark-nlp==5.1.3

# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
Retrieving notices: ...working... done
Collecting package metadata (current_repodata.json): done
Solving environment: done


==> WARNING: A newer version of conda exists. <==
  current version: 23.3.1
  latest version: 23.10.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=23.10.0



## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - openjdk


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2023.08.22 |       h06a4308_0         123 KB
    certifi-2023.11.17         |  py310h06a4308_0         158 KB
    openjdk-11.0.13            |       h87a67e3_0       341.0 MB
    ------------------------------------------------------------
                                           Total:       341.3 MB

The following NEW packages will be INSTALLED:

  openjdk            pkgs/main/linux-64::openjdk-11.0.13-h87a67e3_0 

The following packages will be UPDATED:

  ca-certificates    conda-forge::ca-certificates-2023.7.2~ --> pkgs/main::ca-certificates-2023.08.22-h06a4308_0 
  certifi            conda-forge/noarch::certifi-2023.7.22~ --> pkgs/main/linux-64::certifi-2023.11.17-py310h06a4308_0 



Downloading and Extracting Packages
certifi-2023.11.17   | 158 KB    |                                       |   0% 
ca-certificates-2023 | 123 KB    |                                       |   0% 

openjdk-11.0.13      | 341.0 MB  |                                       |   0% 
certifi-2023.11.17   | 158 KB    | ##################################### | 100% 

openjdk-11.0.13      | 341.0 MB  | 4                                     |   1% 

openjdk-11.0.13      | 341.0 MB  | #9                                    |   5% 

openjdk-11.0.13      | 341.0 MB  | ####                                  |  11% 

openjdk-11.0.13      | 341.0 MB  | ######1                               |  17% 

openjdk-11.0.13      | 341.0 MB  | ########4                             |  23% 

openjdk-11.0.13      | 341.0 MB  | ##########6                           |  29% 

openjdk-11.0.13      | 341.0 MB  | ############5                         |  34% 

openjdk-11.0.13      | 341.0 MB  | ##############6                       |  40% 

openjdk-11.0.13      | 341.0 MB  | ################6                     |  45% 

openjdk-11.0.13      | 341.0 MB  | ##################7                   |  51% 

openjdk-11.0.13      | 341.0 MB  | ####################9                 |  57% 

openjdk-11.0.13      | 341.0 MB  | ######################9               |  62% 

openjdk-11.0.13      | 341.0 MB  | #########################             |  68% 

openjdk-11.0.13      | 341.0 MB  | ###########################1          |  73% 

openjdk-11.0.13      | 341.0 MB  | #############################1        |  79% 

openjdk-11.0.13      | 341.0 MB  | ###############################3      |  85% 

openjdk-11.0.13      | 341.0 MB  | #################################5    |  91% 

openjdk-11.0.13      | 341.0 MB  | ###################################6  |  96% 

                                                                                
                                                                                

                                                                                
Preparing transaction: done
Verifying transaction: done
Executing transaction: done

Note: you may need to restart the kernel to use updated packages.
Collecting pyspark==3.4.0
  Using cached pyspark-3.4.0-py2.py3-none-any.whl
Collecting py4j==0.10.9.7 (from pyspark==3.4.0)
  Using cached py4j-0.10.9.7-py2.py3-none-any.whl (200 kB)
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.7 pyspark-3.4.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
Collecting spark-nlp==5.1.3
  Obtaining dependency information for spark-nlp==5.1.3 from https://files.pythonhosted.org/packages/cd/7d/bc0eca4c9ec4c9c1d9b28c42c2f07942af70980a7d912d0aceebf8db32dd/spark_nlp-5.1.3-py2.py3-none-any.whl.metadata
  Using cached spark_nlp-5.1.3-py2.py3-none-any.whl.metadata (53 kB)
Using cached spark_nlp-5.1.3-py2.py3-none-any.whl (537 kB)
Installing collected packages: spark-nlp
Successfully installed spark-nlp-5.1.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
Out[ ]:
In [ ]:
import sagemaker
session = sagemaker.Session()
bucket = session.default_bucket()

# Import pyspark and build Spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("PySparkApp")
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
    .config(
        "fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)

print(spark.version)
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
Warning: Ignoring non-Spark config property: fs.s3a.aws.credentials.provider
:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml
Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-9848f4a0-613c-4457-b48d-722c6db85236;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.2.2 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.563 in central
:: resolution report :: resolve 335ms :: artifacts dl 34ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.11.563 from central in [default]
	org.apache.hadoop#hadoop-aws;3.2.2 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
	---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent-9848f4a0-613c-4457-b48d-722c6db85236
	confs: [default]
	0 artifacts copied, 2 already retrieved (0kB/22ms)
23/11/30 19:37:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0

Load in Comments Dataframes¶

In [ ]:
# Tegveer's S3 -- DO NOT CHANGE
s3_directory_comms = f"s3a://sagemaker-us-east-1-433974840707/project/ml_comments/"

# Read all the Parquet files in the directory into a DataFrame
df_comments = spark.read.parquet(s3_directory_comms)
23/11/30 19:37:18 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                
In [ ]:
print("Number of records in sampled and filtered df: ", df_comments.count())
[Stage 1:============================================>              (3 + 1) / 4]
Number of records in sampled and filtered df:  13242001
                                                                                
In [ ]:
df_comments.groupby('controversiality').count().show()
[Stage 4:============================================>              (3 + 1) / 4]
+----------------+--------+
|controversiality|   count|
+----------------+--------+
|               0|12311780|
|               1|  930221|
+----------------+--------+

                                                                                
In [ ]:
df_comments.printSchema()
root
 |-- controversiality: string (nullable = true)
 |-- distinguished: string (nullable = true)
 |-- subreddit: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- gilded: integer (nullable = true)
 |-- score: integer (nullable = true)
 |-- weight: double (nullable = true)

Load in models + Evaluate performance¶

In [ ]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, NaiveBayes
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline, Model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc
from pyspark.ml.pipeline import PipelineModel
import matplotlib.pyplot as plt
import seaborn as sns
In [ ]:
# random forest
rf_path = "s3a://sagemaker-us-east-1-433974840707/project/ml_updated/rf/rf.model"
rf = PipelineModel.load(rf_path)

# log reg
lr_path = "s3a://sagemaker-us-east-1-433974840707/project/ml_updated/lr/lr.model"
lr = PipelineModel.load(lr_path)

# gradient boosted tree
gbt_path = "s3a://sagemaker-us-east-1-433974840707/project/ml_updated/gbt/gbt.model"
gbt = PipelineModel.load(gbt_path)

# svm
svm_path = "s3a://sagemaker-us-east-1-433974840707/project/ml_updated/svm/svm.model"
svm = PipelineModel.load(svm_path)
WARNING: An illegal reflective access operation has occurred                    
WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/spark-core_2.12-3.4.0.jar) to field java.math.BigInteger.mag
WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$
WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations
WARNING: All illegal access operations will be denied in a future release
                                                                                

Split data into train, test, and validation¶

In [ ]:
train_data, test_data = df_comments.randomSplit([0.75, 0.25], 24)
print("Number of training records: " + str(train_data.count()))
print("Number of testing records : " + str(test_data.count()))
#print("Number of validation records : " + str(val_data.count()))
train_data.cache()
                                                                                
Number of training records: 9931987
[Stage 121:==========================================>              (3 + 1) / 4]
Number of testing records : 3310014
                                                                                
Out[ ]:
DataFrame[controversiality: string, distinguished: string, subreddit: string, year: int, month: int, day: int, gilded: int, score: int, weight: double]

Random Forests¶

Fit Pipeline and Train Model¶

In [ ]:
predictions_train = rf.transform(train_data)
predictions_train.show(5)
[Stage 127:>                                                        (0 + 1) / 1]
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|controversiality|distinguished|   subreddit|year|month|day|gilded|score|            weight|controversiality_str|distinguished_ix|subreddit_ix|subreddit_vec|            features|       rawPrediction|         probability|prediction|predictedControversiality|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -9|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[7.41897532339150...|[0.24729917744638...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[6.78799717249015...|[0.22626657241633...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[19.8916404614541...|[0.66305468204847...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[19.8916404614541...|[0.66305468204847...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[19.8916404614541...|[0.66305468204847...|       0.0|                        0|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
only showing top 5 rows

                                                                                

Model Train Results¶

In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_train)
                                                                                
In [ ]:
print("Train Accuracy = %g" % accuracy)
print("Train Error = %g" % (1.0 - accuracy))
Train Accuracy = 0.841486
Train Error = 0.158514
In [ ]:
y_pred_train=predictions_train.select("prediction").collect()
y_orig_train=predictions_train.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_train, y_pred_train)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[8010715 1223583]
 [ 350780  346909]]

Model Test Results¶

In [ ]:
predictions_test = rf.transform(test_data)
In [ ]:
# Evaluate the model using ROC AUC
rf_fpr, rf_tpr, thresholds = roc_curve(predictions_test.select("controversiality_str").collect(), predictions_test.select("probability").rdd.map(lambda x: x[0][1]).collect())
rf_roc_auc = auc(rf_fpr, rf_tpr)
                                                                                
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="f1")
f1_score = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="precisionByLabel")
precision = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="recallByLabel")
recall = evaluator.evaluate(predictions_test)
                                                                                
In [ ]:
print("Test Accuracy = %g" % accuracy)
print("Test Error = %g" % (1.0 - accuracy))
print(f"F1-score: {f1_score}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
Test Accuracy = 0.841065
Test Error = 0.158935
F1-score: 0.8677418254546316
Precision: 0.9579295902155427
Recall: 0.8671394341217917
In [ ]:
y_pred_test=predictions_test.select("prediction").collect()
y_orig_test=predictions_test.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_test, y_pred_test)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[2668606  408876]
 [ 117200  115332]]
In [ ]:
binary_evaluator = BinaryClassificationEvaluator(labelCol="controversiality_str", rawPredictionCol="prediction", metricName="areaUnderROC")
au_roc_test = binary_evaluator.evaluate(predictions_test)
                                                                                
In [ ]:
print("Test Area Under ROC = %g" % au_roc_test)
Test Area Under ROC = 0.681561
In [ ]:
sns.heatmap(cm, annot=True, fmt='d')
# Save plot
plt.title("Test Set Confusion Matrix - Random Forest")
plt.xlabel("True Label")
plt.ylabel("Predicted Label")  
plt.savefig('../../data/plots/test_conf_mtx_rf.png',bbox_inches='tight')
plt.savefig('../../website-source/test_conf_mtx_rf.png',bbox_inches='tight')
plt.show()
No description has been provided for this image

Logistic Regression¶

Model Train Results¶

In [ ]:
predictions_train = lr.transform(train_data)
predictions_train.show(5)
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|controversiality|distinguished|   subreddit|year|month|day|gilded|score|            weight|controversiality_str|distinguished_ix|subreddit_ix|subreddit_vec|            features|       rawPrediction|         probability|prediction|predictedControversiality|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -9|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.8074653411129...|[0.30843087898349...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.1851716316184...|[0.45383891643314...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.08152567245059...|[0.52037013696251...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.08152567245059...|[0.52037013696251...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.08152567245059...|[0.52037013696251...|       0.0|                        0|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
only showing top 5 rows

23/11/30 05:20:59 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_train)
                                                                                
In [ ]:
print("Train Accuracy = %g" % accuracy)
print("Train Error = %g" % (1.0 - accuracy))
Train Accuracy = 0.451084
Train Error = 0.548916
In [ ]:
y_pred_train=predictions_train.select("prediction").collect()
y_orig_train=predictions_train.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_train, y_pred_train)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[3946568 5287730]
 [ 164093  533596]]

Model Test Results¶

In [ ]:
predictions_test = lr.transform(test_data)
In [ ]:
# Evaluate the model using ROC AUC
l_fpr, l_tpr, l_thresholds = roc_curve(predictions_test.select("controversiality_str").collect(), predictions_test.select("probability").rdd.map(lambda x: x[0][1]).collect())
l_roc_auc = auc(l_fpr, l_tpr)
                                                                                
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="f1")
f1_score = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="precisionByLabel")
precision = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="recallByLabel")
recall = evaluator.evaluate(predictions_test)
23/11/30 19:47:17 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                
In [ ]:
print("Test Accuracy = %g" % accuracy)
print("Test Error = %g" % (1.0 - accuracy))
print(f"F1-score: {f1_score}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
Test Accuracy = 0.451027
Test Error = 0.548973
F1-score: 0.5613807552998148
Precision: 0.9599576380976849
Recall: 0.42737439244161296
In [ ]:
y_pred_test=predictions_test.select("prediction").collect()
y_orig_test=predictions_test.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_test, y_pred_test)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[1315237 1762245]
 [  54862  177670]]
In [ ]:
sns.heatmap(cm, annot=True, fmt='d')
# Save plot
plt.title("Test Set Confusion Matrix - Logistic Regression")
plt.xlabel("True Label")
plt.ylabel("Predicted Label")  
plt.savefig('../../data/plots/test_conf_mtx_lr.png',bbox_inches='tight')
plt.savefig('../../website-source/test_conf_mtx_lr.png',bbox_inches='tight')
plt.show()
No description has been provided for this image

Gradient Boosted Trees¶

Model Train Results¶

In [ ]:
predictions_train = gbt.transform(train_data)
predictions_train.show(5)
[Stage 146:>                                                        (0 + 1) / 1]
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|controversiality|distinguished|   subreddit|year|month|day|gilded|score|            weight|controversiality_str|distinguished_ix|subreddit_ix|subreddit_vec|            features|       rawPrediction|         probability|prediction|predictedControversiality|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -9|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.20339042685072...|[0.60031573735383...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.6891155405878...|[0.20129324612587...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.76012771693484...|[0.82057609144857...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.76012771693484...|[0.82057609144857...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.76012771693484...|[0.82057609144857...|       0.0|                        0|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+--------------------+----------+-------------------------+
only showing top 5 rows

                                                                                
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_train)
                                                                                
In [ ]:
print("Train Accuracy = %g" % accuracy)
print("Train Error = %g" % (1.0 - accuracy))
Train Accuracy = 0.66775
Train Error = 0.33225
In [ ]:
y_pred_train=predictions_train.select("prediction").collect()
y_orig_train=predictions_train.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_train, y_pred_train)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[6119087 3115211]
 [ 184696  512993]]

Model Test Results¶

In [ ]:
predictions_test = gbt.transform(test_data)
In [ ]:
# Evaluate the model using ROC AUC
gbt_fpr, gbt_tpr, gbt_thresholds = roc_curve(predictions_test.select("controversiality_str").collect(), predictions_test.select("probability").rdd.map(lambda x: x[0][1]).collect())
gbt_roc_auc = auc(gbt_fpr, gbt_tpr)
                                                                                
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="f1")
f1_score = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="precisionByLabel")
precision = evaluator.evaluate(predictions_test)
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="recallByLabel")
recall = evaluator.evaluate(predictions_test)
                                                                                
In [ ]:
print("Test Accuracy = %g" % accuracy)
print("Test Error = %g" % (1.0 - accuracy))
print(f"F1-score: {f1_score}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
Test Accuracy = 0.667229
Test Error = 0.332771
F1-score: 0.7485596614154175
Precision: 0.9706071200617313
Recall: 0.6621361229732619
In [ ]:
y_pred_test=predictions_test.select("prediction").collect()
y_orig_test=predictions_test.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_test, y_pred_test)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[2037712 1039770]
 [  61708  170824]]
In [ ]:
sns.heatmap(cm, annot=True, fmt='d')
# Save plot
plt.title("Test Set Confusion Matrix - Gradient Boosted Trees")
plt.xlabel("True Label")
plt.ylabel("Predicted Label")  
plt.savefig('../../data/plots/test_conf_mtx_gbt.png',bbox_inches='tight')
plt.savefig('../../website-source/test_conf_mtx_gbt.png',bbox_inches='tight')
plt.show()
No description has been provided for this image

Support Vector Machines¶

Model Train Results¶

In [ ]:
predictions_train = svm.transform(train_data)
predictions_train.show(5)
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
|controversiality|distinguished|   subreddit|year|month|day|gilded|score|            weight|controversiality_str|distinguished_ix|subreddit_ix|subreddit_vec|            features|       rawPrediction|prediction|predictedControversiality|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -9|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-1.4294100287020...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|   -2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.4892141130452...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.0862730063351...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.0862730063351...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.0862730063351...|       1.0|                        1|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
only showing top 5 rows

In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_train)
                                                                                
In [ ]:
print("Train Accuracy = %g" % accuracy)
print("Train Error = %g" % (1.0 - accuracy))
Train Accuracy = 0.341534
Train Error = 0.658466
In [ ]:
y_pred_train=predictions_train.select("prediction").collect()
y_orig_train=predictions_train.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_train, y_pred_train)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[2773428 6460870]
 [  79004  618685]]

Model Test Results¶

In [ ]:
predictions_test = svm.transform(test_data)
In [ ]:
predictions_test.show()
[Stage 195:>                                                        (0 + 1) / 1]
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
|controversiality|distinguished|   subreddit|year|month|day|gilded|score|            weight|controversiality_str|distinguished_ix|subreddit_ix|subreddit_vec|            features|       rawPrediction|prediction|predictedControversiality|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
|               0|           no|Ask_Politics|2021|    1|  1|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[-0.0862730063351...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.04804069590150...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    4|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.31666810037489...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    6|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.58529550484826...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|    6|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[0.58529550484826...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  1|     0|   59|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,1...|[7.70392172339268...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  2|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,2...|[-0.0862261073123...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  2|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,2...|[-0.0862261073123...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  2|     0|    2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,2...|[0.04808759492436...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  2|     0|   12|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,2...|[1.39122461729124...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    1|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[-0.0861792082894...|       1.0|                        1|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.04813449394725...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    2|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.04813449394725...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    3|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.18244819618396...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    3|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.18244819618396...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    4|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.31676189842065...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|    6|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[0.58538930289401...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|   12|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[1.39127151631413...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  3|     0|   71|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,3...|[9.31577994827867...|       0.0|                        0|
|               0|           no|Ask_Politics|2021|    1|  4|     0|   -6|0.5377776812126273|                 0.0|             0.0|         8.0|    (8,[],[])|[0.0,2021.0,1.0,4...|[-1.0263282249233...|       1.0|                        1|
+----------------+-------------+------------+----+-----+---+------+-----+------------------+--------------------+----------------+------------+-------------+--------------------+--------------------+----------+-------------------------+
only showing top 20 rows

                                                                                
In [ ]:
# Evaluate the model using ROC AUC
svm_fpr, svm_tpr, svm_thresholds = roc_curve(predictions_test.select("controversiality_str").collect(), predictions_test.select("probability").rdd.map(lambda x: x[0][1]).collect())
svm_roc_auc = auc(svm_fpr, svm_tpr)
                                                                                
---------------------------------------------------------------------------
AnalysisException                         Traceback (most recent call last)
Cell In[55], line 2
      1 # Evaluate the model using ROC AUC
----> 2 svm_fpr, svm_tpr, svm_thresholds = roc_curve(predictions_test.select("controversiality_str").collect(), predictions_test.select("probability").rdd.map(lambda x: x[0][1]).collect())
      3 svm_roc_auc = auc(svm_fpr, svm_tpr)

File /opt/conda/lib/python3.10/site-packages/pyspark/sql/dataframe.py:3036, in DataFrame.select(self, *cols)
   2991 def select(self, *cols: "ColumnOrName") -> "DataFrame":  # type: ignore[misc]
   2992     """Projects a set of expressions and returns a new :class:`DataFrame`.
   2993 
   2994     .. versionadded:: 1.3.0
   (...)
   3034     +-----+---+
   3035     """
-> 3036     jdf = self._jdf.select(self._jcols(*cols))
   3037     return DataFrame(jdf, self.sparkSession)

File /opt/conda/lib/python3.10/site-packages/py4j/java_gateway.py:1322, in JavaMember.__call__(self, *args)
   1316 command = proto.CALL_COMMAND_NAME +\
   1317     self.command_header +\
   1318     args_command +\
   1319     proto.END_COMMAND_PART
   1321 answer = self.gateway_client.send_command(command)
-> 1322 return_value = get_return_value(
   1323     answer, self.gateway_client, self.target_id, self.name)
   1325 for temp_arg in temp_args:
   1326     if hasattr(temp_arg, "_detach"):

File /opt/conda/lib/python3.10/site-packages/pyspark/errors/exceptions/captured.py:175, in capture_sql_exception.<locals>.deco(*a, **kw)
    171 converted = convert_exception(e.java_exception)
    172 if not isinstance(converted, UnknownException):
    173     # Hide where the exception came from that shows a non-Pythonic
    174     # JVM exception message.
--> 175     raise converted from None
    176 else:
    177     raise

AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `probability` cannot be resolved. Did you mean one of the following? [`prediction`, `subreddit`, `controversiality`, `day`, `gilded`].;
'Project ['probability]
+- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, subreddit_ix#5068, subreddit_vec#5088, features#5114, rawPrediction#5131, prediction#5150, UDF(cast(prediction#5150 as double)) AS predictedControversiality#5188]
   +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, subreddit_ix#5068, subreddit_vec#5088, features#5114, rawPrediction#5131, UDF(rawPrediction#5131) AS prediction#5150]
      +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, subreddit_ix#5068, subreddit_vec#5088, features#5114, UDF(features#5114) AS rawPrediction#5131]
         +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, subreddit_ix#5068, subreddit_vec#5088, UDF(struct(distinguished_ix, distinguished_ix#5050, year_double_VectorAssembler_c2a03cfc9ed0, cast(year#3 as double), month_double_VectorAssembler_c2a03cfc9ed0, cast(month#4 as double), day_double_VectorAssembler_c2a03cfc9ed0, cast(day#5 as double), score_double_VectorAssembler_c2a03cfc9ed0, cast(score#7 as double), gilded_double_VectorAssembler_c2a03cfc9ed0, cast(gilded#6 as double), subreddit_ix, subreddit_ix#5068)) AS features#5114]
            +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, subreddit_ix#5068, UDF(cast(subreddit_ix#5068 as double), 0) AS subreddit_vec#5088]
               +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, distinguished_ix#5050, UDF(cast(subreddit#2 as string)) AS subreddit_ix#5068]
                  +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, controversiality_str#5033, UDF(cast(distinguished#1 as string)) AS distinguished_ix#5050]
                     +- Project [controversiality#0, distinguished#1, subreddit#2, year#3, month#4, day#5, gilded#6, score#7, weight#8, UDF(cast(controversiality#0 as string)) AS controversiality_str#5033]
                        +- Sample 0.75, 1.0, false, 24
                           +- Sort [controversiality#0 ASC NULLS FIRST, distinguished#1 ASC NULLS FIRST, subreddit#2 ASC NULLS FIRST, year#3 ASC NULLS FIRST, month#4 ASC NULLS FIRST, day#5 ASC NULLS FIRST, gilded#6 ASC NULLS FIRST, score#7 ASC NULLS FIRST, weight#8 ASC NULLS FIRST], false
                              +- Relation [controversiality#0,distinguished#1,subreddit#2,year#3,month#4,day#5,gilded#6,score#7,weight#8] parquet
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="controversiality_str", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions_test)
f1_score = evaluator.evaluate(predictions_test)
precision = evaluator.evaluate(predictions_test)
recall = evaluator.evaluate(predictions_test)
                                                                                
In [ ]:
print("Test Accuracy = %g" % accuracy)
print("Test Error = %g" % (1.0 - accuracy))
print(f"F1-score: {f1_score}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
Test Accuracy = 0.341289
Test Error = 0.658711
F1-score: 0.3412888888083253
Precision: 0.3412888888083253
Recall: 0.3412888888083253
In [ ]:
y_pred_test=predictions_test.select("prediction").collect()
y_orig_test=predictions_test.select("controversiality_str").collect()
                                                                                
In [ ]:
cm = confusion_matrix(y_orig_test, y_pred_test)
print("Confusion Matrix:")
print(cm)
Confusion Matrix:
[[ 923640 2153842]
 [  26501  206031]]
In [ ]:
sns.heatmap(cm, annot=True, fmt='d')
# Save plot
plt.title("Test Set Confusion Matrix - Support Vector Machines")
plt.xlabel("True Label")
plt.ylabel("Predicted Label")  
plt.savefig('../../data/plots/test_conf_mtx_svm.png',bbox_inches='tight')
plt.savefig('../../website-source/test_conf_mtx_svm.png',bbox_inches='tight')
plt.show()
No description has been provided for this image

Test Set AUC-ROC Curve Plot¶

In [ ]:
# Create a plot of the ROC curve
plt.figure(figsize=(8,6))
lw = 2
plt.plot(rf_fpr, rf_tpr, color='darkgreen', lw=lw, label='Random Forest (area = %0.2f)' % rf_roc_auc)
plt.plot(l_fpr, l_tpr, color='darkorange', lw=lw, label='Logistic Regression (area = %0.2f)' % l_roc_auc)
plt.plot(gbt_fpr, gbt_tpr, color='darkred', lw=lw, label='Gradient Boosted Trees (area = %0.2f)' % gbt_roc_auc)
# plt.plot(svm_fpr, svm_tpr, color='darkblue', lw=lw, label='Support Vector Machines (area = %0.2f)' % svm_roc_auc)
plt.plot([0, 1], [0, 1], color='darkgoldenrod', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Test Set Receiver Operating Curve - Classifier Models')
plt.legend(loc="lower right")

# Save plot
plt.savefig('../../data/plots/AUC-ROC-controv.png',bbox_inches='tight')
plt.savefig('../../website-source/AUC-ROC-controv.png',bbox_inches='tight')

plt.show()
No description has been provided for this image
In [ ]: