Code: ML-Topic 8

# Setup - Run only once per Kernel App
%conda install openjdk -y

# install PySpark
%pip install pyspark==3.4.0

# install spark-nlp
%pip install spark-nlp==5.1.3
%pip install sparknlp  # optional: the sparknlp Python module is already provided by spark-nlp above

# install plotly
%pip install plotly

# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
Collecting package metadata (current_repodata.json): done
Solving environment: done


## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - openjdk


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2023.08.22 |       h06a4308_0         123 KB
    certifi-2023.11.17         |  py310h06a4308_0         158 KB
    openjdk-11.0.13            |       h87a67e3_0       341.0 MB
    ------------------------------------------------------------
                                           Total:       341.3 MB

The following NEW packages will be INSTALLED:

  openjdk            pkgs/main/linux-64::openjdk-11.0.13-h87a67e3_0 

The following packages will be UPDATED:

  ca-certificates    conda-forge::ca-certificates-2023.7.2~ --> pkgs/main::ca-certificates-2023.08.22-h06a4308_0 
  certifi            conda-forge/noarch::certifi-2023.7.22~ --> pkgs/main/linux-64::certifi-2023.11.17-py310h06a4308_0 



Downloading and Extracting Packages: done
Preparing transaction: done
Verifying transaction: done
Executing transaction: done

Note: you may need to restart the kernel to use updated packages.
Collecting pyspark==3.4.0
  Using cached pyspark-3.4.0-py2.py3-none-any.whl
Collecting py4j==0.10.9.7 (from pyspark==3.4.0)
  Using cached py4j-0.10.9.7-py2.py3-none-any.whl (200 kB)
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.7 pyspark-3.4.0
Collecting spark-nlp==5.1.3
  Obtaining dependency information for spark-nlp==5.1.3 from https://files.pythonhosted.org/packages/cd/7d/bc0eca4c9ec4c9c1d9b28c42c2f07942af70980a7d912d0aceebf8db32dd/spark_nlp-5.1.3-py2.py3-none-any.whl.metadata
  Using cached spark_nlp-5.1.3-py2.py3-none-any.whl.metadata (53 kB)
Using cached spark_nlp-5.1.3-py2.py3-none-any.whl (537 kB)
Installing collected packages: spark-nlp
Successfully installed spark-nlp-5.1.3
Collecting sparknlp
  Using cached sparknlp-1.0.0-py3-none-any.whl (1.4 kB)
Requirement already satisfied: spark-nlp in /opt/conda/lib/python3.10/site-packages (from sparknlp) (5.1.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from sparknlp) (1.26.0)
Installing collected packages: sparknlp
Successfully installed sparknlp-1.0.0
Requirement already satisfied: plotly in /opt/conda/lib/python3.10/site-packages (5.9.0)
Requirement already satisfied: tenacity>=6.2.0 in /opt/conda/lib/python3.10/site-packages (from plotly) (8.0.1)
Note: you may need to restart the kernel to use updated packages.
## Import packages
import json
import sparknlp
import numpy as np
import pandas as pd
from sparknlp.base import *
from pyspark.ml import Pipeline
from sparknlp.annotator import *
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType, StringType
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import MinMaxScaler, StringIndexer, VectorIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier, NaiveBayes 
from pyspark.sql.functions import (
    monotonically_increasing_id, row_number, mean, stddev, max, min, count,
    percentile_approx, year, month, dayofmonth, ceil, col, dayofweek, hour,
    explode, date_format, lower, size, split, regexp_replace, isnan, when)
from pyspark.sql import SparkSession
from sparknlp.pretrained import PretrainedPipeline
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.subplots as sp
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook_connected"
from pyspark.sql import SparkSession
from py4j.java_gateway import java_import

spark = SparkSession.builder \
    .appName("Spark ML") \
    .config("spark.driver.memory", "16G") \
    .config("spark.driver.maxResultSize", "0") \
    .config("spark.kryoserializer.buffer.max", "2000M") \
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3,org.apache.hadoop:hadoop-aws:3.2.2") \
    .config(
        # Hadoop options need the "spark.hadoop." prefix; a bare fs.s3a. key is ignored by Spark
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    ) \
    .getOrCreate()

print(f"Spark version: {spark.version}")
print(f"sparknlp version: {sparknlp.version()}")
[Ivy dependency log trimmed: resolved com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 and org.apache.hadoop#hadoop-aws;3.2.2 plus transitive dependencies (73 artifacts) from Maven Central.]
23/11/29 20:16:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Spark version: 3.4.0
sparknlp version: 5.1.3
## Read cleaned data from parquet

### Anime subreddits
import sagemaker
# session = sagemaker.Session()
# bucket = session.default_bucket()
bucket = 'sagemaker-us-east-1-315969085594'

com_bucket_path = f"s3a://{bucket}/project/cleaned/com"

print(f"reading comments from {com_bucket_path}")
com = spark.read.parquet(com_bucket_path)  # parquet files are self-describing; no header option needed
print(f"shape of the com dataframe is {com.count():,}x{len(com.columns)}")
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
reading comments from s3a://sagemaker-us-east-1-315969085594/project/cleaned/com
23/11/29 20:16:21 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
shape of the com dataframe is 6,879,119x19
                                                                                
com.printSchema()
root
 |-- subreddit: string (nullable = true)
 |-- author: string (nullable = true)
 |-- author_flair_text: string (nullable = true)
 |-- created_utc: timestamp (nullable = true)
 |-- body: string (nullable = true)
 |-- controversiality: long (nullable = true)
 |-- score: long (nullable = true)
 |-- parent_id: string (nullable = true)
 |-- stickied: boolean (nullable = true)
 |-- link_id: string (nullable = true)
 |-- id: string (nullable = true)
 |-- created_date: string (nullable = true)
 |-- created_hour: integer (nullable = true)
 |-- created_week: integer (nullable = true)
 |-- created_month: integer (nullable = true)
 |-- created_year: integer (nullable = true)
 |-- cleaned: string (nullable = true)
 |-- body_wordCount: integer (nullable = true)
 |-- contain_pokemon: boolean (nullable = true)

### Select features to be used

selected_features_com = com.select(F.col("controversiality").cast(IntegerType()),
                                   "score","stickied","body_wordCount",
                                   F.col("body").alias("text")
                                  )

### Add a feature: sentiment analysis results from Spark NLP

# Add sentiment category (negative, positive) as a feature
document = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

# keep letters and whitespace only (use r"[^\w\d\s]" instead to keep alphanumeric chars)
cleanUpPatterns = [r"[^a-zA-Z\s]+"]

documentNormalizer = DocumentNormalizer() \
    .setInputCols("document") \
    .setOutputCol("normalizedDocument") \
    .setAction("clean") \
    .setPatterns(cleanUpPatterns) \
    .setReplacement(" ") \
    .setPolicy("pretty_all") \
    .setLowercase(True)

token = Tokenizer() \
    .setInputCols(["normalizedDocument"]) \
    .setOutputCol("token")

normalizer = Normalizer() \
    .setInputCols(["token"]) \
    .setOutputCol("normal")

vivekn = ViveknSentimentModel.pretrained() \
    .setInputCols(["document", "normal"]) \
    .setOutputCol("result_sentiment")

finisher = Finisher() \
    .setInputCols(["result_sentiment"]) \
    .setOutputCols(["final_sentiment"])

sentiment_pipeline = Pipeline().setStages(
    [document, documentNormalizer, token, normalizer, vivekn, finisher])
sentiment_vivekn download started this may take some time.
Approximate size to download 873.6 KB
Download done! Loading the resource.
[OK!]
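Before fitting on the full ~6.9M comments, a quick sanity check on a couple of hand-written strings can catch wiring mistakes cheaply. A minimal sketch using only the pipeline defined above (the sample sentences are made up):

# Sanity-check sketch: fit and run the sentiment pipeline on two toy rows
sample = spark.createDataFrame(
    [("this show was absolutely fantastic",),
     ("worst episode i have ever watched",)],
    ["text"])
sentiment_pipeline.fit(sample).transform(sample) \
    .select("text", "final_sentiment").show(truncate=False)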
text_result = sentiment_pipeline.fit(selected_features_com).transform(selected_features_com)
text_result.show(5)
+----------------+-----+--------+--------------+--------------------+---------------+
|controversiality|score|stickied|body_wordCount|                text|final_sentiment|
+----------------+-----+--------+--------------+--------------------+---------------+
|               0|    1|   false|             6|  i sent it to ya ;)|     [negative]|
|               0|    1|   false|            16|displate has some...|     [positive]|
|               0|    3|   false|             6|that sounds like ...|     [positive]|
|               0|    1|   false|            10|what kind of ques...|     [negative]|
|               0|    4|   false|            28|today on shokugek...|     [positive]|
+----------------+-----+--------+--------------+--------------------+---------------+
only showing top 5 rows
                                                                                

### Make a balanced dataset in terms of the target variable (controversiality)

# downsample the non-controversial class (6,768,222 rows) to roughly match the 110,897 controversial comments
selected_features_com_0_sample = text_result.filter(F.col("controversiality")==0).sample(fraction=110897/6768222, withReplacement=False)
selected_features_com_1 = text_result.filter(F.col("controversiality")==1)

balanced_selected_features_com = selected_features_com_0_sample.union(selected_features_com_1).cache()
balanced_selected_features_com.groupby("controversiality").count().show()
+----------------+------+
|controversiality| count|
+----------------+------+
|               0|110433|
|               1|110897|
+----------------+------+
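The sampling fraction above is hard-coded from the class counts. A sketch of an equivalent stratified sample that derives the fraction at run time with sampleBy (seed chosen arbitrarily):

# Alternative sketch: compute the downsampling fraction from the live class counts
class_counts = {row["controversiality"]: row["count"]
                for row in text_result.groupBy("controversiality").count().collect()}
balanced_alt = text_result.sampleBy(
    "controversiality",
    fractions={0: class_counts[1] / class_counts[0], 1: 1.0},
    seed=42).cache()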
                                                                                
# unpack the single-element sentiment array and drop the raw text column
balanced_selected_features_com = balanced_selected_features_com.withColumn("final_sentiment",F.col("final_sentiment")[0]).drop("text")
balanced_selected_features_com.printSchema()
root
 |-- controversiality: integer (nullable = true)
 |-- score: long (nullable = true)
 |-- stickied: boolean (nullable = true)
 |-- body_wordCount: integer (nullable = true)
 |-- final_sentiment: string (nullable = true)
balanced_selected_features_com.show(3)
+----------------+-----+--------+--------------+---------------+
|controversiality|score|stickied|body_wordCount|final_sentiment|
+----------------+-----+--------+--------------+---------------+
|               0|    2|   false|             1|             na|
|               0|    2|   false|             5|       positive|
|               0|    2|   false|            65|       negative|
+----------------+-----+--------+--------------+---------------+
only showing top 3 rows
                                                                                
# drop rows containing nulls; note that the literal string "na" produced by Vivekn
# for undecidable texts is not a null, so those rows survive this drop
balanced_selected_features_com = balanced_selected_features_com.na.drop(how='any')

## Machine Learning Pipeline

### Naive Bayes Classification

stringIndexer_sentiment = StringIndexer(inputCol="final_sentiment", outputCol="sentiment_ind")

# MinMaxScaler for the numeric variables "score" and "body_wordCount"
# (multinomial NaiveBayes requires nonnegative features, and raw comment scores can be negative)
columns_to_scale = ["score", "body_wordCount"]
numeric_var_assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in columns_to_scale]
scalers = [MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in columns_to_scale]

# Assemble all features into a single vector column
all_features = VectorAssembler(
    inputCols=["stickied", "sentiment_ind"] + [col + "_scaled" for col in columns_to_scale],
    outputCol="all_features")

model_nb = NaiveBayes(modelType="multinomial", featuresCol="all_features", labelCol="controversiality")

pipeline_model_nb = Pipeline(stages=[stringIndexer_sentiment] +
                             numeric_var_assemblers + scalers +
                             [all_features, model_nb])
# balanced_selected_features_com_sample = balanced_selected_features_com.sample(fraction=0.005, withReplacement=False).cache()
(trainingData, testData) = balanced_selected_features_com.randomSplit([0.7, 0.3])
trained_model = pipeline_model_nb.fit(trainingData)
                                                                                
predictions = trained_model.transform(testData)
predictions.select("controversiality","probability","prediction").show(5)
+----------------+--------------------+----------+
|controversiality|         probability|prediction|
+----------------+--------------------+----------+
|               0|[0.49743486047435...|       1.0|
|               0|[0.50160588758410...|       0.0|
|               0|[0.50241503292721...|       0.0|
|               0|[0.49677557064141...|       1.0|
|               0|[0.49666228697880...|       1.0|
+----------------+--------------------+----------+
only showing top 5 rows
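The fitted Naive Bayes stage can be pulled out of the pipeline to inspect what it learned. A small sketch:

# the last pipeline stage is the fitted NaiveBayesModel
nb_model = trained_model.stages[-1]
print(nb_model.pi)     # log prior for each class
print(nb_model.theta)  # log conditional probabilities of each feature per class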

#### Predictions Evaluation: ROC

evaluator = BinaryClassificationEvaluator(labelCol="controversiality", rawPredictionCol="prediction", metricName="areaUnderROC")
roc_result = evaluator.evaluate(predictions)
roc_result
                                                                                
0.5203000583352344
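Because rawPredictionCol is set to the hard "prediction" column, this AUC is computed from thresholded 0/1 labels. A sketch evaluating on the probability vector instead, which is usually the more informative number:

# Sketch: AUC from the probability vector rather than the thresholded prediction
evaluator_prob = BinaryClassificationEvaluator(
    labelCol="controversiality",
    rawPredictionCol="probability",  # vector columns are accepted here
    metricName="areaUnderROC")
print(evaluator_prob.evaluate(predictions))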
label_preds = predictions.select("controversiality","probability","prediction").toPandas()
labels = label_preds["controversiality"]  # true labels (0s and 1s)
predicted_probs = np.array([p[1] for p in label_preds["probability"]])  # predicted probability of class 1

# Compute ROC curve
fpr, tpr, thresholds = roc_curve(labels, predicted_probs)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure(figsize=(7, 6))
plt.plot(fpr, tpr, color='#42a1b9', lw=2, label=f'AUC = {roc_auc:.2f}')
plt.plot([0, 1], [0, 1], color='#d13a47', lw=1.5, linestyle='--')
plt.xlabel('False Positive Rate',fontsize=12)
plt.ylabel('True Positive Rate',fontsize=12)
plt.title('ROC Curve for NaiveBayes Classification',fontsize=14)
plt.legend(loc='lower right')
plt.savefig("../../website-source/images/ROC_NB.png")
plt.show()
                                                                                

#### Accuracy and Confusion Matrix

# Calculate the elements of the confusion matrix
TN = predictions.filter('prediction = 0 AND controversiality = prediction').count()
TP = predictions.filter('prediction = 1 AND controversiality = prediction').count()
FN = predictions.filter('prediction = 0 AND controversiality = 1').count()
FP = predictions.filter('prediction = 1 AND controversiality = 0').count()

# Accuracy measures the proportion of correct predictions
accuracy = (TN + TP) / (TN + TP + FN + FP)
print(accuracy)
0.5199542870891101
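The same accuracy can be computed in a single pass with the MulticlassClassificationEvaluator already imported above. A sketch:

# one evaluator call instead of four separate filter/count jobs
acc_evaluator = MulticlassClassificationEvaluator(
    labelCol="controversiality", predictionCol="prediction",
    metricName="accuracy")
print(acc_evaluator.evaluate(predictions))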
                                                                                
predicted_values = label_preds["prediction"].to_numpy()
# Compute confusion matrix
cm = confusion_matrix(labels, predicted_values)

# Plot confusion matrix using seaborn heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap=sns.cubehelix_palette(as_cmap=True), annot_kws={"size": 14})
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix for NaiveBayes Classification')
plt.savefig("../../website-source/images/cm_NB.png")
plt.show()

### Decision Tree Classification

model_dt = DecisionTreeClassifier(featuresCol="all_features", labelCol="controversiality")


pipeline_model_dt = Pipeline(stages=[stringIndexer_sentiment]+
                          numeric_var_assemblers+scalers+
                          [all_features,model_dt])
trained_model = pipeline_model_dt.fit(trainingData)
                                                                                
predictions2 = trained_model.transform(testData)
predictions2.select("controversiality","probability","prediction").show(5)
+----------------+--------------------+----------+
|controversiality|         probability|prediction|
+----------------+--------------------+----------+
|               0|[0.15181268882175...|       1.0|
|               0|[0.15181268882175...|       1.0|
|               0|[0.15181268882175...|       1.0|
|               0|[0.15181268882175...|       1.0|
|               0|[0.15181268882175...|       1.0|
+----------------+--------------------+----------+
only showing top 5 rows

#### Predictions Evaluation

evaluator2 = BinaryClassificationEvaluator(labelCol="controversiality", rawPredictionCol="prediction", metricName="areaUnderROC")
roc_result2 = evaluator2.evaluate(predictions2)
roc_result2
                                                                                
0.7345819765253864
label_preds2 = predictions2.select("controversiality","probability","prediction").toPandas()
labels = label_preds2["controversiality"]  # true labels (0s and 1s)
predicted_probs = np.array([p[1] for p in label_preds2["probability"]])  # predicted probability of class 1

# Compute ROC curve
fpr, tpr, thresholds = roc_curve(labels, predicted_probs)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure(figsize=(7, 6))
plt.plot(fpr, tpr, color='#42a1b9', lw=2, label=f'AUC = {roc_auc:.2f}')
plt.plot([0, 1], [0, 1], color='#d13a47', lw=1.5, linestyle='--')
plt.xlabel('False Positive Rate',fontsize=12)
plt.ylabel('True Positive Rate',fontsize=12)
plt.title('ROC Curve for Decision Tree Classification',fontsize=14)
plt.legend(loc='lower right')
plt.savefig("../../website-source/images/ROC_DT.png")
plt.show()
                                                                                

#### Accuracy and Confusion Matrix

# Calculate the elements of the confusion matrix
TN = predictions2.filter('prediction = 0 AND controversiality = prediction').count()
TP = predictions2.filter('prediction = 1 AND controversiality = prediction').count()
FN = predictions2.filter('prediction = 0 AND controversiality = 1').count()
FP = predictions2.filter('prediction = 1 AND controversiality = 0').count()

# Accuracy measures the proportion of correct predictions
accuracy = (TN + TP) / (TN + TP + FN + FP)
print(accuracy)
0.7355267510751556
                                                                                
predicted_values = label_preds2["prediction"].to_numpy()
# Compute confusion matrix
cm = confusion_matrix(labels, predicted_values)

# Plot confusion matrix using seaborn heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap=sns.cubehelix_palette(as_cmap=True), annot_kws={"size": 14})
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix for Decision Tree Classification')
plt.savefig("../../website-source/images/cm_DT.png")
plt.show()

### Hyperparameter Tuning for Maximum Depth (maxDepth)

def dt_model(maxDep):
    model_dt_tmp = DecisionTreeClassifier(featuresCol="all_features", labelCol="controversiality", maxDepth=maxDep)

    # the pipeline must use the freshly parameterized model, not the global model_dt
    pipeline_model_dt_tmp = Pipeline(stages=[stringIndexer_sentiment]+
                              numeric_var_assemblers+scalers+
                              [all_features, model_dt_tmp])

    trained_model_tmp = pipeline_model_dt_tmp.fit(trainingData)
    predictions_tmp = trained_model_tmp.transform(testData)

    # Calculate the elements of the confusion matrix from this model's predictions
    TN = predictions_tmp.filter('prediction = 0 AND controversiality = prediction').count()
    TP = predictions_tmp.filter('prediction = 1 AND controversiality = prediction').count()
    FN = predictions_tmp.filter('prediction = 0 AND controversiality = 1').count()
    FP = predictions_tmp.filter('prediction = 1 AND controversiality = 0').count()

    # Accuracy measures the proportion of correct predictions
    accuracy_tmp = (TN + TP) / (TN + TP + FN + FP)
    return accuracy_tmp
values_list = list(range(4, 20))
acc = [dt_model(i) for i in values_list]
                                                                                
plt.figure(figsize=(7, 6))
plt.plot(values_list, acc, color='#42a1b9', lw=2)
plt.xlabel('Maximum Depth',fontsize=12)
plt.ylabel('Accuracy',fontsize=12)
plt.title('Hyperparameter Tuning for Maximum Depth of Decision Tree',fontsize=14)
# plt.savefig("../../website-source/images/maxDepth_DT.png")
plt.show()
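A more idiomatic alternative to the hand-rolled loop is Spark's built-in tuning utilities. A sketch with CrossValidator over the same pipeline (grid and fold count chosen arbitrarily):

# Sketch: search maxDepth with 3-fold cross-validation
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

param_grid = (ParamGridBuilder()
              .addGrid(model_dt.maxDepth, list(range(4, 20, 3)))
              .build())
cv = CrossValidator(estimator=pipeline_model_dt,
                    estimatorParamMaps=param_grid,
                    evaluator=MulticlassClassificationEvaluator(
                        labelCol="controversiality",
                        predictionCol="prediction",
                        metricName="accuracy"),
                    numFolds=3)
cv_model = cv.fit(trainingData)
print(cv_model.avgMetrics)  # mean CV accuracy for each maxDepth in the grid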