MACHINE LEARNING

SETUP

In [ ]:
# Setup - Run only once per Kernel App
%conda install openjdk -y

# install PySpark
%pip install pyspark==3.4.0

# install spark-nlp
%pip install spark-nlp==5.1.3

# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
Retrieving notices: ...working... done
Collecting package metadata (current_repodata.json): done
Solving environment: done


==> WARNING: A newer version of conda exists. <==
  current version: 23.3.1
  latest version: 23.10.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=23.10.0



## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - openjdk


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2023.08.22 |       h06a4308_0         123 KB
    certifi-2023.11.17         |  py310h06a4308_0         158 KB
    openjdk-11.0.13            |       h87a67e3_0       341.0 MB
    ------------------------------------------------------------
                                           Total:       341.3 MB

The following NEW packages will be INSTALLED:

  openjdk            pkgs/main/linux-64::openjdk-11.0.13-h87a67e3_0 

The following packages will be UPDATED:

  ca-certificates    conda-forge::ca-certificates-2023.7.2~ --> pkgs/main::ca-certificates-2023.08.22-h06a4308_0 
  certifi            conda-forge/noarch::certifi-2023.7.22~ --> pkgs/main/linux-64::certifi-2023.11.17-py310h06a4308_0 



Downloading and Extracting Packages
Preparing transaction: done
Verifying transaction: done
Executing transaction: done

Note: you may need to restart the kernel to use updated packages.
Collecting pyspark==3.4.0
  Using cached pyspark-3.4.0-py2.py3-none-any.whl
Collecting py4j==0.10.9.7 (from pyspark==3.4.0)
  Using cached py4j-0.10.9.7-py2.py3-none-any.whl (200 kB)
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.7 pyspark-3.4.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
Collecting spark-nlp==5.1.3
  Obtaining dependency information for spark-nlp==5.1.3 from https://files.pythonhosted.org/packages/cd/7d/bc0eca4c9ec4c9c1d9b28c42c2f07942af70980a7d912d0aceebf8db32dd/spark_nlp-5.1.3-py2.py3-none-any.whl.metadata
  Using cached spark_nlp-5.1.3-py2.py3-none-any.whl.metadata (53 kB)
Using cached spark_nlp-5.1.3-py2.py3-none-any.whl (537 kB)
Installing collected packages: spark-nlp
Successfully installed spark-nlp-5.1.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.

CLEANING

In [ ]:
# Import pyspark and build the Spark session
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Spark NLP")\
    .master("local[*]")\
    .config("spark.driver.memory", "16g")\
    .config("spark.executor.memory", "12g")\
    .config("spark.executor.cores", "3")\
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3,org.apache.hadoop:hadoop-aws:3.2.2")\
    .config(
            # Hadoop options need the "spark.hadoop." prefix; Spark ignores the bare key
            "spark.hadoop.fs.s3a.aws.credentials.provider",
            "com.amazonaws.auth.ContainerCredentialsProvider"
    )\
    .getOrCreate()

print(spark.version)
:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml
Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-0a72fff1-6a0f-413c-a4ed-44cb433aeee1;1.0
	confs: [default]
	found com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 in central
	found com.typesafe#config;1.4.2 in central
	found org.rocksdb#rocksdbjni;6.29.5 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.828 in central
	found com.github.universal-automata#liblevenshtein;3.0.0 in central
	found com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central
	found com.google.protobuf#protobuf-java;3.0.0-beta-3 in central
	found com.google.code.gson#gson;2.3 in central
	found it.unimi.dsi#fastutil;7.0.12 in central
	found org.projectlombok#lombok;1.16.8 in central
	found com.google.cloud#google-cloud-storage;2.20.1 in central
	found com.google.guava#guava;31.1-jre in central
	found com.google.guava#failureaccess;1.0.1 in central
	found com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava in central
	found com.google.errorprone#error_prone_annotations;2.18.0 in central
	found com.google.j2objc#j2objc-annotations;1.3 in central
	found com.google.http-client#google-http-client;1.43.0 in central
	found io.opencensus#opencensus-contrib-http-util;0.31.1 in central
	found com.google.http-client#google-http-client-jackson2;1.43.0 in central
	found com.google.http-client#google-http-client-gson;1.43.0 in central
	found com.google.api-client#google-api-client;2.2.0 in central
	found commons-codec#commons-codec;1.15 in central
	found com.google.oauth-client#google-oauth-client;1.34.1 in central
	found com.google.http-client#google-http-client-apache-v2;1.43.0 in central
	found com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 in central
	found com.google.code.gson#gson;2.10.1 in central
	found com.google.cloud#google-cloud-core;2.12.0 in central
	found io.grpc#grpc-context;1.53.0 in central
	found com.google.auto.value#auto-value-annotations;1.10.1 in central
	found com.google.auto.value#auto-value;1.10.1 in central
	found javax.annotation#javax.annotation-api;1.3.2 in central
	found commons-logging#commons-logging;1.2 in central
	found com.google.cloud#google-cloud-core-http;2.12.0 in central
	found com.google.http-client#google-http-client-appengine;1.43.0 in central
	found com.google.api#gax-httpjson;0.108.2 in central
	found com.google.cloud#google-cloud-core-grpc;2.12.0 in central
	found io.grpc#grpc-alts;1.53.0 in central
	found io.grpc#grpc-grpclb;1.53.0 in central
	found org.conscrypt#conscrypt-openjdk-uber;2.5.2 in central
	found io.grpc#grpc-auth;1.53.0 in central
	found io.grpc#grpc-protobuf;1.53.0 in central
	found io.grpc#grpc-protobuf-lite;1.53.0 in central
	found io.grpc#grpc-core;1.53.0 in central
	found com.google.api#gax;2.23.2 in central
	found com.google.api#gax-grpc;2.23.2 in central
	found com.google.auth#google-auth-library-credentials;1.16.0 in central
	found com.google.auth#google-auth-library-oauth2-http;1.16.0 in central
	found com.google.api#api-common;2.6.2 in central
	found io.opencensus#opencensus-api;0.31.1 in central
	found com.google.api.grpc#proto-google-iam-v1;1.9.2 in central
	found com.google.protobuf#protobuf-java;3.21.12 in central
	found com.google.protobuf#protobuf-java-util;3.21.12 in central
	found com.google.api.grpc#proto-google-common-protos;2.14.2 in central
	found org.threeten#threetenbp;1.6.5 in central
	found com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha in central
	found com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha in central
	found com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha in central
	found com.fasterxml.jackson.core#jackson-core;2.14.2 in central
	found com.google.code.findbugs#jsr305;3.0.2 in central
	found io.grpc#grpc-api;1.53.0 in central
	found io.grpc#grpc-stub;1.53.0 in central
	found org.checkerframework#checker-qual;3.31.0 in central
	found io.perfmark#perfmark-api;0.26.0 in central
	found com.google.android#annotations;4.1.1.4 in central
	found org.codehaus.mojo#animal-sniffer-annotations;1.22 in central
	found io.opencensus#opencensus-proto;0.2.0 in central
	found io.grpc#grpc-services;1.53.0 in central
	found com.google.re2j#re2j;1.6 in central
	found io.grpc#grpc-netty-shaded;1.53.0 in central
	found io.grpc#grpc-googleapis;1.53.0 in central
	found io.grpc#grpc-xds;1.53.0 in central
	found com.navigamez#greex;1.0 in central
	found dk.brics.automaton#automaton;1.11-8 in central
	found com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 in central
	found com.microsoft.onnxruntime#onnxruntime;1.15.0 in central
	found org.apache.hadoop#hadoop-aws;3.2.2 in central
:: resolution report :: resolve 3909ms :: artifacts dl 492ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.11.828 from central in [default]
	com.fasterxml.jackson.core#jackson-core;2.14.2 from central in [default]
	com.github.universal-automata#liblevenshtein;3.0.0 from central in [default]
	com.google.android#annotations;4.1.1.4 from central in [default]
	com.google.api#api-common;2.6.2 from central in [default]
	com.google.api#gax;2.23.2 from central in [default]
	com.google.api#gax-grpc;2.23.2 from central in [default]
	com.google.api#gax-httpjson;0.108.2 from central in [default]
	com.google.api-client#google-api-client;2.2.0 from central in [default]
	com.google.api.grpc#gapic-google-cloud-storage-v2;2.20.1-alpha from central in [default]
	com.google.api.grpc#grpc-google-cloud-storage-v2;2.20.1-alpha from central in [default]
	com.google.api.grpc#proto-google-cloud-storage-v2;2.20.1-alpha from central in [default]
	com.google.api.grpc#proto-google-common-protos;2.14.2 from central in [default]
	com.google.api.grpc#proto-google-iam-v1;1.9.2 from central in [default]
	com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 from central in [default]
	com.google.auth#google-auth-library-credentials;1.16.0 from central in [default]
	com.google.auth#google-auth-library-oauth2-http;1.16.0 from central in [default]
	com.google.auto.value#auto-value;1.10.1 from central in [default]
	com.google.auto.value#auto-value-annotations;1.10.1 from central in [default]
	com.google.cloud#google-cloud-core;2.12.0 from central in [default]
	com.google.cloud#google-cloud-core-grpc;2.12.0 from central in [default]
	com.google.cloud#google-cloud-core-http;2.12.0 from central in [default]
	com.google.cloud#google-cloud-storage;2.20.1 from central in [default]
	com.google.code.findbugs#jsr305;3.0.2 from central in [default]
	com.google.code.gson#gson;2.10.1 from central in [default]
	com.google.errorprone#error_prone_annotations;2.18.0 from central in [default]
	com.google.guava#failureaccess;1.0.1 from central in [default]
	com.google.guava#guava;31.1-jre from central in [default]
	com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava from central in [default]
	com.google.http-client#google-http-client;1.43.0 from central in [default]
	com.google.http-client#google-http-client-apache-v2;1.43.0 from central in [default]
	com.google.http-client#google-http-client-appengine;1.43.0 from central in [default]
	com.google.http-client#google-http-client-gson;1.43.0 from central in [default]
	com.google.http-client#google-http-client-jackson2;1.43.0 from central in [default]
	com.google.j2objc#j2objc-annotations;1.3 from central in [default]
	com.google.oauth-client#google-oauth-client;1.34.1 from central in [default]
	com.google.protobuf#protobuf-java;3.21.12 from central in [default]
	com.google.protobuf#protobuf-java-util;3.21.12 from central in [default]
	com.google.re2j#re2j;1.6 from central in [default]
	com.johnsnowlabs.nlp#spark-nlp_2.12;5.1.3 from central in [default]
	com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 from central in [default]
	com.microsoft.onnxruntime#onnxruntime;1.15.0 from central in [default]
	com.navigamez#greex;1.0 from central in [default]
	com.typesafe#config;1.4.2 from central in [default]
	commons-codec#commons-codec;1.15 from central in [default]
	commons-logging#commons-logging;1.2 from central in [default]
	dk.brics.automaton#automaton;1.11-8 from central in [default]
	io.grpc#grpc-alts;1.53.0 from central in [default]
	io.grpc#grpc-api;1.53.0 from central in [default]
	io.grpc#grpc-auth;1.53.0 from central in [default]
	io.grpc#grpc-context;1.53.0 from central in [default]
	io.grpc#grpc-core;1.53.0 from central in [default]
	io.grpc#grpc-googleapis;1.53.0 from central in [default]
	io.grpc#grpc-grpclb;1.53.0 from central in [default]
	io.grpc#grpc-netty-shaded;1.53.0 from central in [default]
	io.grpc#grpc-protobuf;1.53.0 from central in [default]
	io.grpc#grpc-protobuf-lite;1.53.0 from central in [default]
	io.grpc#grpc-services;1.53.0 from central in [default]
	io.grpc#grpc-stub;1.53.0 from central in [default]
	io.grpc#grpc-xds;1.53.0 from central in [default]
	io.opencensus#opencensus-api;0.31.1 from central in [default]
	io.opencensus#opencensus-contrib-http-util;0.31.1 from central in [default]
	io.opencensus#opencensus-proto;0.2.0 from central in [default]
	io.perfmark#perfmark-api;0.26.0 from central in [default]
	it.unimi.dsi#fastutil;7.0.12 from central in [default]
	javax.annotation#javax.annotation-api;1.3.2 from central in [default]
	org.apache.hadoop#hadoop-aws;3.2.2 from central in [default]
	org.checkerframework#checker-qual;3.31.0 from central in [default]
	org.codehaus.mojo#animal-sniffer-annotations;1.22 from central in [default]
	org.conscrypt#conscrypt-openjdk-uber;2.5.2 from central in [default]
	org.projectlombok#lombok;1.16.8 from central in [default]
	org.rocksdb#rocksdbjni;6.29.5 from central in [default]
	org.threeten#threetenbp;1.6.5 from central in [default]
	:: evicted modules:
	com.google.protobuf#protobuf-java-util;3.0.0-beta-3 by [com.google.protobuf#protobuf-java-util;3.21.12] in [default]
	com.google.protobuf#protobuf-java;3.0.0-beta-3 by [com.google.protobuf#protobuf-java;3.21.12] in [default]
	com.google.code.gson#gson;2.3 by [com.google.code.gson#gson;2.10.1] in [default]
	com.amazonaws#aws-java-sdk-bundle;1.11.563 by [com.amazonaws#aws-java-sdk-bundle;1.11.828] in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   77  |   0   |   0   |   4   ||   73  |   0   |
	---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent-0a72fff1-6a0f-413c-a4ed-44cb433aeee1
	confs: [default]
	0 artifacts copied, 73 already retrieved (0kB/160ms)
23/11/28 17:44:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0
In [ ]:
import sagemaker
import sparknlp
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import col, concat_ws, length, lower, regexp_replace
from sparknlp.annotator import *
from sparknlp.base import *  # also provides DocumentAssembler and Finisher
from sparknlp.pretrained import PretrainedPipeline

Import Data

In [ ]:
%%time
bucket = "project-group34"
session = sagemaker.Session()
output_prefix_data_submissions = "project/submissions/yyyy=*"
s3_path = f"s3a://{bucket}/{output_prefix_data_submissions}"
print(f"reading submissions from {s3_path}")
# Parquet files carry their own schema, so no header option is needed
submissions = spark.read.parquet(s3_path)
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
reading submissions from s3a://project-group34/project/submissions/yyyy=*
23/11/27 21:29:41 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                
CPU times: user 199 ms, sys: 16.9 ms, total: 216 ms
Wall time: 6.49 s
23/11/27 21:29:47 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.

Data Processing

In [ ]:
print(f"Shape of the submissions data is: {submissions.count()} x {len(submissions.columns)}")
Shape of the submissions data is: 875969 x 68
                                                                                
In [ ]:
submissions = submissions.select("subreddit", "title", "selftext", "score", "num_comments", "over_18", "is_self", "is_video", "domain", "created_utc", "author", "author_flair_text", "media")
In [ ]:
# Post length: number of characters in the title plus the selftext
submissions = submissions.withColumn('post_length', length(submissions.title) + length(submissions.selftext))
In [ ]:
from pyspark.sql import functions as F

submissions = submissions.withColumn('created_utc', F.to_timestamp('created_utc'))

# Extract time-based features
submissions = submissions.withColumn('hour_of_day', F.hour('created_utc'))
submissions = submissions.withColumn('day_of_week', F.dayofweek('created_utc'))  # 1 (Sunday) to 7 (Saturday)
# Map each day of the week from numeric to string
submissions = submissions.withColumn('day_of_week_str', F.expr("""
    CASE day_of_week 
        WHEN 1 THEN 'Sunday'
        WHEN 2 THEN 'Monday'
        WHEN 3 THEN 'Tuesday'
        WHEN 4 THEN 'Wednesday'
        WHEN 5 THEN 'Thursday'
        WHEN 6 THEN 'Friday'
        WHEN 7 THEN 'Saturday'
    END
"""))
submissions = submissions.withColumn('day_of_month', F.dayofmonth('created_utc'))
submissions = submissions.withColumn('month', F.month('created_utc'))
submissions = submissions.withColumn('year', F.year('created_utc'))

submissions = submissions.withColumn('has_media', F.col('media').isNotNull())
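Spark's `dayofweek` numbers days from 1 (Sunday) to 7 (Saturday), while Python's `datetime.isoweekday()` runs from 1 (Monday) to 7 (Sunday). The plain-Python sketch below (no Spark required) reproduces Spark's convention so the CASE mapping above can be sanity-checked locally. `spark_dayofweek` and `DAY_NAMES` are hypothetical helper names introduced here for illustration; they are not part of the notebook.

```python
from datetime import datetime

# Hypothetical lookup mirroring the CASE expression: Spark index -> day name
DAY_NAMES = {1: "Sunday", 2: "Monday", 3: "Tuesday", 4: "Wednesday",
             5: "Thursday", 6: "Friday", 7: "Saturday"}

def spark_dayofweek(ts: datetime) -> int:
    """Return the day index in Spark's convention: 1 = Sunday ... 7 = Saturday."""
    # isoweekday() gives Monday = 1 ... Sunday = 7, so Sunday wraps around to 1
    return ts.isoweekday() % 7 + 1

sample = datetime(2021, 1, 27, 22, 0)  # same date and hour as the first rows shown by show(5)
idx = spark_dayofweek(sample)
print(idx, DAY_NAMES[idx])
```

For 27 January 2021 this gives index 4 and "Wednesday", which agrees with the `day_of_week` and `day_of_week_str` columns in the sample rows.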
In [ ]:
submissions = submissions.drop("media")
In [ ]:
submissions = submissions.select('subreddit',
                                 'title',
                                 'selftext',
                                 'score',
                                 'num_comments',
                                 'over_18',
                                 'is_self',
                                 'is_video',
                                 'domain',
                                 'post_length',
                                 'hour_of_day',
                                 'day_of_week',
                                 'day_of_week_str',
                                 'day_of_month',
                                 'month',
                                 'year',
                                 'has_media')
In [ ]:
# Combine 'title' and 'selftext' into a new column 'body'
submissions = submissions.withColumn("body", concat_ws(" ", col("title"), col("selftext")))
In [ ]:
submissions = submissions.drop("title", "selftext")
In [ ]:
submissions.show(5)
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+
| subreddit|score|num_comments|over_18|is_self|is_video|         domain|post_length|hour_of_day|day_of_week|day_of_week_str|day_of_month|month|year|has_media|                body|
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+
|television|    0|           9|  false|   true|   false|self.television|        605|         22|          4|      Wednesday|          27|    1|2021|    false|Is there a websit...|
|     anime|    0|           3|  false|  false|   false|      i.redd.it|         50|         22|          4|      Wednesday|          27|    1|2021|    false|Does anyone know ...|
|television|    4|          11|  false|  false|   false|   deadline.com|         86|         22|          4|      Wednesday|          27|    1|2021|    false|‘Doogie Kameāloha...|
|    movies|    0|           4|  false|   true|   false|    self.movies|         42|         22|          4|      Wednesday|          27|    1|2021|    false|4K movies on desk...|
|     anime|    0|           9|  false|   true|   false|     self.anime|         64|         22|          4|      Wednesday|          27|    1|2021|    false|Where can I buy a...|
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+---------------+------------+-----+----+---------+--------------------+
only showing top 5 rows

                                                                                
In [ ]:
from pyspark.sql.functions import col, count, when

missing_vals = submissions.select([count(when(col(c).isNull(), c)).alias(c) for c in submissions.columns])
In [ ]:
missing_vals.show()
+---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+
|subreddit|score|num_comments|over_18|is_self|is_video|domain|post_length|hour_of_day|day_of_week|day_of_week_str|day_of_month|month|year|has_media|body|
+---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+
|        0|    0|           0|      0|      0|       0|  8002|          0|          0|          0|              0|           0|    0|   0|        0|   0|
+---------+-----+------------+-------+-------+--------+------+-----------+-----------+-----------+---------------+------------+-----+----+---------+----+
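The null count above relies on two behaviors: `when(...)` without an `otherwise(...)` yields null for non-matching rows, and `count(...)` skips nulls. The plain-Python model below is only a sketch of that logic on made-up rows; `when_is_null` and `count_non_null` are hypothetical stand-ins for the Spark functions, not real API calls.

```python
# Made-up rows for illustration; None plays the role of SQL NULL
rows = [
    {"subreddit": "anime", "domain": "youtu.be"},
    {"subreddit": "movies", "domain": None},
    {"subreddit": "television", "domain": None},
]

def when_is_null(value, label):
    # Models when(col.isNull(), label) with no otherwise(): non-matching rows yield None
    return label if value is None else None

def count_non_null(values):
    # Models SQL count(expr): rows where expr is NULL are not counted
    return sum(v is not None for v in values)

# One count per column, like the list comprehension over submissions.columns
missing = {c: count_non_null(when_is_null(r[c], c) for r in rows) for c in rows[0]}
print(missing)
```

Only rows where the column is null survive the `when`, so the count per column equals the number of missing values in that column.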

                                                                                
In [ ]:
submissions = submissions.na.drop(subset=["domain"])
In [ ]:
from pyspark.sql.functions import col, lower, regexp_replace

# Lowercase the text
submissions = submissions.withColumn("body", lower(col("body")))

# Replace newline characters with spaces
submissions = submissions.withColumn("body", regexp_replace(col("body"), "\n", " "))

# Remove punctuation (drops every character that is not an ASCII letter, digit, or whitespace)
submissions = submissions.withColumn("body", regexp_replace(col("body"), r"[^a-zA-Z0-9\s]", ""))
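To see what these three steps do to a single string, here is a minimal plain-Python sketch using the `re` module. `clean_body` is a hypothetical helper and the sample text is invented. Python's `\s` also matches some Unicode whitespace that Java's regex engine (used by Spark) does not, so treat this as an approximation of the Spark behavior rather than an exact replica.

```python
import re

def clean_body(text: str) -> str:
    """Approximate the Spark steps: lowercase, newline -> space, strip other characters."""
    text = text.lower()
    text = text.replace("\n", " ")
    # Keep only ASCII letters, digits, and whitespace
    return re.sub(r"[^a-zA-Z0-9\s]", "", text)

raw = "Café Society: 4K re-release!\nWorth it?"  # invented example text
print(clean_body(raw))
```

Note that accented letters such as "é" are removed along with punctuation, which is why non-ASCII characters disappear from `body` in the later output.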
In [ ]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF

# Tokenize text
tokenizer = Tokenizer(inputCol="body", outputCol="words")
tokenized_df = tokenizer.transform(submissions)

# Remove stop words
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
df_no_stopwords = remover.transform(tokenized_df)

# Hash the filtered tokens into sparse term-frequency vectors (default size 2^18 = 262144)
hashingTF = HashingTF(inputCol="filtered_words", outputCol="rawFeatures")
featurizedData = hashingTF.transform(df_no_stopwords)

# Rescale the term frequencies by inverse document frequency (TF-IDF)
idf = IDF(inputCol="rawFeatures", outputCol="features")
rescaledData = idf.fit(featurizedData).transform(featurizedData)
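The hashing and IDF steps can be sketched in plain Python to show what ends up in the `features` column. This is an illustration under stated assumptions, not Spark's implementation: it uses CRC32 as a stand-in hash (Spark uses MurmurHash3, so real bucket indices differ), the toy documents are invented, and `bucket`, `term_frequencies`, and `idf_weights` are hypothetical names. The smoothed formula log((m + 1) / (df + 1)) is the one documented for Spark ML's `IDF`.

```python
import math
import zlib
from collections import Counter

NUM_FEATURES = 1 << 18  # 262144, HashingTF's default number of buckets

def bucket(term: str) -> int:
    # Stand-in hash (CRC32); Spark uses MurmurHash3, so real indices differ
    return zlib.crc32(term.encode("utf-8")) % NUM_FEATURES

def term_frequencies(tokens):
    """Map bucket index -> raw count for one document (the hashing trick)."""
    return Counter(bucket(t) for t in tokens)

def idf_weights(docs_tf):
    """Smoothed IDF per bucket: log((m + 1) / (df + 1)), with m documents."""
    m = len(docs_tf)
    # Iterating a Counter yields its keys, so each document counts once per bucket
    df = Counter(idx for tf in docs_tf for idx in tf)
    return {idx: math.log((m + 1) / (d + 1)) for idx, d in df.items()}

docs = [["anime", "mix", "amv"], ["anime", "manga"], ["movies", "4k"]]  # invented tokens
docs_tf = [term_frequencies(d) for d in docs]
idf = idf_weights(docs_tf)
tfidf = [{idx: count * idf[idx] for idx, count in tf.items()} for tf in docs_tf]
print(tfidf[0])
```

Each document becomes a sparse mapping from bucket index to weight, which corresponds to the size, indices, and values that Spark prints for a sparse vector. Terms appearing in many documents get a smaller weight than rare ones.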
                                                                                
In [ ]:
rescaledData = rescaledData.drop(*["day_of_week_str", "words", "filtered_words", "rawFeatures"])
In [ ]:
rescaledData.printSchema()
root
 |-- subreddit: string (nullable = true)
 |-- score: long (nullable = true)
 |-- num_comments: long (nullable = true)
 |-- over_18: boolean (nullable = true)
 |-- is_self: boolean (nullable = true)
 |-- is_video: boolean (nullable = true)
 |-- domain: string (nullable = true)
 |-- post_length: integer (nullable = true)
 |-- hour_of_day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- day_of_month: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- has_media: boolean (nullable = false)
 |-- body: string (nullable = false)
 |-- features: vector (nullable = true)

In [ ]:
rescaledData.show()
23/11/27 21:34:10 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+
| subreddit|score|num_comments|over_18|is_self|is_video|         domain|post_length|hour_of_day|day_of_week|day_of_month|month|year|has_media|                body|            features|
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+
|television|    0|           9|  false|   true|   false|self.television|        605|         22|          4|          27|    1|2021|    false|is there a websit...|(262144,[1546,158...|
|     anime|    0|           3|  false|  false|   false|      i.redd.it|         50|         22|          4|          27|    1|2021|    false|does anyone know ...|(262144,[101370,1...|
|television|    4|          11|  false|  false|   false|   deadline.com|         86|         22|          4|          27|    1|2021|    false|doogie kameloha m...|(262144,[16384,73...|
|    movies|    0|           4|  false|   true|   false|    self.movies|         42|         22|          4|          27|    1|2021|    false|4k movies on desk...|(262144,[1206,202...|
|     anime|    0|           9|  false|   true|   false|     self.anime|         64|         22|          4|          27|    1|2021|    false|where can i buy a...|(262144,[53596,92...|
|     anime|    0|           7|  false|   true|   false|     self.anime|       1249|         22|          4|          27|    1|2021|    false|ever suddenly fin...|(262144,[2437,392...|
|    movies|    1|           0|  false|  false|   false|     apple.news|         64|         22|          4|          27|    1|2021|    false|cloris leachman m...|(262144,[27708,48...|
|television|    2|           0|  false|  false|   false|    variety.com|         91|         22|          4|          27|    1|2021|    false|netflix alan yang...|(262144,[24303,27...|
|     anime|    1|           6|  false|  false|   false|       youtu.be|         67|         22|          4|          27|    1|2021|     true|the way eren spea...|(262144,[51471,53...|
|television|    3|           2|  false|   true|   false|self.television|         75|         22|          4|          27|    1|2021|    false|whats the best wa...|(262144,[8145,271...|
|     anime|    1|           0|  false|  false|   false|       youtu.be|         31|         22|          4|          27|    1|2021|     true|anime mix amv mmm...|(262144,[7231,820...|
|     anime|    1|           1|  false|  false|   false|       youtu.be|         28|         22|          4|          27|    1|2021|     true|eromanga sensei  ...|(262144,[8408,112...|
|     anime|    1|           1|  false|  false|   false|       youtu.be|         45|         22|          4|          27|    1|2021|     true|god of destructio...|(262144,[35501,55...|
|     anime|    1|           1|  false|   true|   false|     self.anime|         25|         22|          4|          27|    1|2021|    false|manga collectors ...|(262144,[41314,61...|
|     anime|    1|           1|  false|   true|   false|     self.anime|         29|         22|          4|          27|    1|2021|    false|wholesome animes ...|(262144,[49869,61...|
|    movies|    1|           0|  false|   true|   false|    self.movies|         64|         22|          4|          27|    1|2021|    false|a crime in a buil...|(262144,[33803,61...|
|     anime|    1|           3|  false|   true|   false|     self.anime|         38|         22|          4|          27|    1|2021|    false|question for mang...|(262144,[41314,61...|
|    movies|    1|          17|  false|   true|   false|    self.movies|         58|         22|          4|          27|    1|2021|    false|daredevil or ghos...|(262144,[61710,13...|
|     anime|    1|           5|  false|   true|   false|     self.anime|         69|         22|          4|          27|    1|2021|    false|idk what the tita...|(262144,[61710,79...|
|    movies|   15|          54|  false|   true|   false|    self.movies|        595|         22|          4|          27|    1|2021|    false|with hollywood ma...|(262144,[4978,853...|
+----------+-----+------------+-------+-------+--------+---------------+-----------+-----------+-----------+------------+-----+----+---------+--------------------+--------------------+
only showing top 20 rows

                                                                                
In [ ]:
rescaledData.select("features").show(1, truncate=False)
23/11/27 21:34:11 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|features                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|(262144,[1546,15867,27139,51471,52351,53570,54245,60776,61756,64358,91878,94214,111767,113004,116485,138201,140461,141589,143918,146337,156724,158661,161756,168211,179943,183553,190256,206312,206910,208258,213767,229604,231350,232735,234233,243658,245599,245731,249180,249943,254072,254600],[5.69251854502339,5.760023411747378,7.579618575735043,3.5506848158879545,4.968082016336675,13.530412171330834,5.253448124146994,6.604035998144887,4.909856857255696,4.2268389298045514,6.532216203833996,4.525126060326382,3.6691755070510474,4.716142116632298,6.872627092131839,11.461674818651053,12.80489709590626,7.257177844091133,5.128712738777624,8.83762821965198,7.575835844437219,5.314775239845497,9.66657694137099,5.447604138587951,12.287615765483569,7.473400952560769,10.178630837871902,4.128098018339647,5.057958163163958,2.268569872313168,4.6447319837013845,5.203915567000776,11.594468584923623,5.498081117888862,8.820816413026089,3.522859618212378,3.808643786887529,3.026130860184214,29.77500852413356,5.293682790260379,9.722666408022032,7.643224866342196])|
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
only showing top 1 row
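
The features value above is printed in Spark's sparse-vector form (size, [indices], [values]): 262144 buckets, of which only the listed indices hold a non-zero weight. As a rough illustration in plain Python (no Spark required), the sketch below unpacks a shortened copy of that row. The helper name sparse_lookup is hypothetical, and only the first three index/value pairs are copied from the output.

```python
# Shortened copy of the row shown above: (size, indices, values).
size = 262144
indices = [1546, 15867, 27139]
values = [5.69251854502339, 5.760023411747378, 7.579618575735043]


def sparse_lookup(index, indices, values):
    """Return the stored weight for `index`, or 0.0 if the bucket is empty."""
    weights = dict(zip(indices, values))
    return weights.get(index, 0.0)


print(sparse_lookup(1546, indices, values))  # a stored weight
print(sparse_lookup(0, indices, values))     # an empty bucket reads as 0.0
print(f"non-zero fraction in this shortened copy: {len(indices) / size:.6f}")
```

Storing only the non-zero entries is what keeps a 262144-wide text vector cheap to write to Parquet and to pass into a VectorAssembler later on.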

In [ ]:
from pyspark.sql.functions import col

rescaledData = rescaledData.withColumn("over_18", col("over_18").cast("string"))
rescaledData = rescaledData.withColumn("is_self", col("is_self").cast("string"))
rescaledData = rescaledData.withColumn("is_video", col("is_video").cast("string"))
rescaledData = rescaledData.withColumn("has_media", col("has_media").cast("string"))
In [ ]:
rescaledData.write.parquet("s3a://project-group34/project/submissions/cleaned_ML/", mode="overwrite")
23/11/27 21:34:13 WARN DAGScheduler: Broadcasting large task binary with size 4.3 MiB
                                                                                

ML¶

SPARK JOB¶

In [ ]:
%%time
import time
import sagemaker
from sagemaker.spark.processing import PySparkProcessor

# Set up the PySpark processor that runs the job. SageMaker launches
# `instance_count` instances of type `instance_type` for the Spark job.
role = sagemaker.get_execution_role()
spark_processor = PySparkProcessor(
    base_job_name="sm-spark-project",
    framework_version="3.3",
    role=role,
    instance_count=8,
    instance_type="ml.m5.xlarge",
    max_runtime_in_seconds=21600,
)

# # S3 URI of the initialization script
# s3_uri_init_script = f's3://{bucket}/{script_key}'

# s3 paths
session = sagemaker.Session()
output_prefix_logs = "spark_logs"

configuration = [
    {
        "Classification": "spark-defaults",
        "Properties": {"spark.executor.memory": "12g", "spark.executor.cores": "4"},
    }
]
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
CPU times: user 1.37 s, sys: 539 ms, total: 1.91 s
Wall time: 1.47 s
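
The executor settings above can be sanity-checked against the instance size with a little arithmetic. The sketch below assumes the published ml.m5.xlarge shape (4 vCPUs, 16 GiB) and Spark's default executor memory overhead (10% of executor memory, at least 384 MiB). It ignores whatever memory the operating system, YARN, and the driver reserve inside the SageMaker container, so treat it as an estimate only.

```python
INSTANCE_VCPUS = 4        # ml.m5.xlarge
INSTANCE_MEMORY_GIB = 16  # ml.m5.xlarge
INSTANCE_COUNT = 8        # instance_count in the PySparkProcessor call

executor_memory_gib = 12  # spark.executor.memory = "12g"
executor_cores = 4        # spark.executor.cores = "4"

# Spark's default overhead: 10% of executor memory, with a 384 MiB floor.
overhead_gib = max(0.10 * executor_memory_gib, 384 / 1024)
requested_gib = executor_memory_gib + overhead_gib

print(f"per-executor request: {requested_gib:.1f} GiB of {INSTANCE_MEMORY_GIB} GiB")
print(f"executors per instance (by cores): {INSTANCE_VCPUS // executor_cores}")
print(f"total vCPUs in the cluster: {INSTANCE_COUNT * INSTANCE_VCPUS}")
```

Under these assumptions each instance hosts a single executor that asks for roughly 13.2 GiB, which leaves a few GiB of headroom on a 16 GiB machine.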

RANDOM FOREST¶

PREPROCESSING JOB¶
In [ ]:
!mkdir -p ./code
In [ ]:
%%writefile ./code/preprocess_ml.py

import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length


logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

def main():
    
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_dataset_path", type=str, help="Path of dataset in S3")    
    args = parser.parse_args()

    spark = SparkSession.builder \
    .appName("Spark ML")\
    .config("spark.driver.memory","16G")\
    .config("spark.driver.maxResultSize", "0") \
    .config("spark.kryoserializer.buffer.max", "2000M")\
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3")\
    .getOrCreate()
    
    logger.info(f"Spark version: {spark.version}")
    
    # This is needed to save RDDs which is the only way to write nested Dataframes into CSV format
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    # Read the cleaned dataset from S3 into a DataFrame
    # (Parquet stores its own schema, so no header option is needed)
    logger.info(f"going to read {args.s3_dataset_path}")
    rescaledData = spark.read.parquet(args.s3_dataset_path)
    vals = ["movies", "anime"]
    rescaledData = rescaledData.where(col("subreddit").isin(vals))
    
    train_data, test_data, val_data = rescaledData.randomSplit([0.8, 0.15, 0.05], seed=1220)
    
    # Print the number of records in each dataset
    logger.info("Number of training records: " + str(train_data.count()))
    logger.info("Number of testing records: " + str(test_data.count()))
    logger.info("Number of validation records: " + str(val_data.count()))
    
    stringIndexer_over_18 = StringIndexer(inputCol="over_18", outputCol="over_18_ix")
    stringIndexer_is_self = StringIndexer(inputCol="is_self", outputCol="is_self_ix")
    stringIndexer_is_video = StringIndexer(inputCol="is_video", outputCol="is_video_ix")
    stringIndexer_has_media = StringIndexer(inputCol="has_media", outputCol="has_media_ix")
    stringIndexer_subreddit = StringIndexer(inputCol="subreddit", outputCol="subreddit_ix")
    onehot_over_18 = OneHotEncoder(inputCol="over_18_ix", outputCol="over_18_vec")
    onehot_is_self = OneHotEncoder(inputCol="is_self_ix", outputCol="is_self_vec")
    onehot_is_video = OneHotEncoder(inputCol="is_video_ix", outputCol="is_video_vec")
    onehot_has_media = OneHotEncoder(inputCol="has_media_ix", outputCol="has_media_vec")
    
    vectorAssembler_features = VectorAssembler(
                                    inputCols=["features", "over_18_vec", "is_self_vec", "is_video_vec", 
                                               "has_media_vec", "score", "num_comments",
                                               "post_length", "hour_of_day", "day_of_week", 
                                               "day_of_month", "month", "year"],
                                    outputCol="combined_features")

    # Define the stages for the pipeline
    stages = [
        stringIndexer_subreddit, 
        stringIndexer_over_18, 
        stringIndexer_is_self, 
        stringIndexer_is_video, 
        stringIndexer_has_media,
        onehot_over_18,
        onehot_is_self,
        onehot_is_video,
        onehot_has_media,
        vectorAssembler_features
    ]

    # Define the pipeline without the classifier and evaluator
    pipeline = Pipeline(stages=stages)

    # Fit the preprocessing part of the pipeline
    pipeline_fit = pipeline.fit(train_data)

    # Transform the data
    transformed_train_data = pipeline_fit.transform(train_data)
    transformed_test_data = pipeline_fit.transform(test_data)
    
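    # Overwrite earlier output so the job can be re-run without failing on an existing path
    transformed_train_data.write.parquet("s3a://project-group34/project/submissions/preprocessed_ML/train/", mode="overwrite")
    transformed_test_data.write.parquet("s3a://project-group34/project/submissions/preprocessed_ML/test/", mode="overwrite")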
    
if __name__ == "__main__":
    main()
Overwriting ./code/preprocess_ml.py
In [ ]:
%%time
bucket = "project-group34"
s3_path = "s3a://project-group34/project/submissions/cleaned_ML/"

# Run the job. The arguments list is passed to the Python script on the command line.
spark_processor.run(
    submit_app="./code/preprocess_ml.py",
    arguments=[
        "--s3_dataset_path",
        s3_path,
    ],
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(bucket, output_prefix_logs),
    logs=False,
    configuration=configuration
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-28-04-04-56-000
CPU times: user 359 ms, sys: 45.3 ms, total: 404 ms
Wall time: 8min 34s
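
The preprocessing script relies on StringIndexer to turn subreddit into the numeric label subreddit_ix. With its default ordering (frequencyDesc), the most frequent string receives index 0.0 and ties are broken alphabetically, so the mapping depends on the class counts in the training split. The plain-Python sketch below imitates that ordering on a toy column; the function name frequency_desc_labels and the toy data are hypothetical and only illustrate the rule.

```python
from collections import Counter


def frequency_desc_labels(values):
    """Order labels as StringIndexer's default 'frequencyDesc' does:
    most frequent first, ties broken alphabetically."""
    counts = Counter(values)
    return sorted(counts, key=lambda label: (-counts[label], label))


# Hypothetical toy column; the real counts come from the training split.
toy_subreddits = ["movies", "anime", "movies", "movies", "anime"]
labels = frequency_desc_labels(toy_subreddits)
index_of = {label: float(i) for i, label in enumerate(labels)}

print(labels)    # label order, position = index
print(index_of)  # label -> numeric index
```

Any label list handed to IndexToString in a later job has to follow this same order, otherwise predictions are decoded to the wrong subreddit. In a fitted pipeline the order can be read from the StringIndexerModel's labels attribute instead of being typed by hand.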
HYPERPARAMETERS - DEFAULT¶
In [ ]:
%%writefile ./code/ml_train_RF1.py

import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length


logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

def main():

    spark = (SparkSession.builder\
             .appName("PySparkApp")\
             .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")\
             .config(
                    "fs.s3a.aws.credentials.provider",
                    "com.amazonaws.auth.ContainerCredentialsProvider",
                )\
    .getOrCreate()
)
    
    logger.info(f"Spark version: {spark.version}")
    
    # This is needed to save RDDs which is the only way to write nested Dataframes into CSV format
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    # Read the preprocessed train and test sets from S3
    transformed_train_data = spark.read.parquet("s3a://project-group34/project/submissions/preprocessed_ML/train/")
    transformed_test_data = spark.read.parquet("s3a://project-group34/project/submissions/preprocessed_ML/test/")
    
    # RandomForestClassifier without hyperparameter tuning
    rf_classifier = RandomForestClassifier(labelCol="subreddit_ix", 
                                           featuresCol="combined_features")
    
    # Map numeric predictions back to subreddit names. The label order must match the
    # indices assigned by the StringIndexer in preprocess_ml.py, and the names must
    # match the subreddit values ("movies", not "movie").
    labelConverter = IndexToString(inputCol="prediction", 
                               outputCol="predictedSubreddit", 
                               labels=["anime", "movies"])

    # Chain the classifier and the labelConverter in one pipeline
    pipeline = Pipeline(stages=[rf_classifier, 
                                labelConverter])
    
    # Fit the entire pipeline
    pipeline_fit = pipeline.fit(transformed_train_data)
    train_predictions = pipeline_fit.transform(transformed_train_data)

    # Apply the fitted pipeline to the test data
    test_predictions = pipeline_fit.transform(transformed_test_data)
    
    train_predictions.write.parquet("s3a://project-group34/project/submissions/RandomForest/default/train_pred/", mode="overwrite")    
    test_predictions.write.parquet("s3a://project-group34/project/submissions/RandomForest/default/test_pred/", mode="overwrite")
    pipeline_fit.save("s3a://project-group34/project/submissions/RandomForest/default/model/")
    
    logger.info(f"all done...")
    
if __name__ == "__main__":
    main()
Overwriting ./code/ml_train_RF1.py
In [ ]:
%%time
bucket = "project-group34"

# Run the job. This script takes no command-line arguments.
spark_processor.run(
    submit_app="./code/ml_train_RF1.py",
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(bucket, output_prefix_logs),
    logs=False,
    configuration=configuration
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-30-23-59-44-449
CPU times: user 1 s, sys: 83.5 ms, total: 1.08 s
Wall time: 22min 46s
HYPERPARAMETERS - numTrees=100, maxDepth=10, maxBins=64¶
In [ ]:
%%writefile ./code/ml_train_RF2.py

import sys
import os
import logging
import argparse
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline, Model
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import length


logging.basicConfig(format='%(asctime)s,%(levelname)s,%(module)s,%(filename)s,%(lineno)d,%(message)s', level=logging.DEBUG)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

def main():

    spark = (SparkSession.builder\
             .appName("PySparkApp")\
             .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")\
             .config(
                    "fs.s3a.aws.credentials.provider",
                    "com.amazonaws.auth.ContainerCredentialsProvider",
                )\
    .getOrCreate()
)
    
    logger.info(f"Spark version: {spark.version}")
    
    # This is needed to save RDDs which is the only way to write nested Dataframes into CSV format
    sc = spark.sparkContext
    sc._jsc.hadoopConfiguration().set(
        "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter"
    )

    # Read the preprocessed train and test sets from S3
    transformed_train_data = spark.read.parquet("s3a://project-group34/project/submissions/preprocessed_ML/train/")
    transformed_test_data = spark.read.parquet("s3a://project-group34/project/submissions/preprocessed_ML/test/")
    
    # RandomForestClassifier with manually chosen hyperparameters
    rf_classifier = RandomForestClassifier(labelCol="subreddit_ix", 
                                           featuresCol="combined_features", 
                                           numTrees=100, 
                                           maxDepth=10,
                                           maxBins=64)

    # Map numeric predictions back to subreddit names. The label order must match the
    # indices assigned by the StringIndexer in preprocess_ml.py, and the names must
    # match the subreddit values ("movies", not "movie").
    labelConverter = IndexToString(inputCol="prediction", 
                               outputCol="predictedSubreddit", 
                               labels=["anime", "movies"])

    # Chain the classifier and the labelConverter in one pipeline
    pipeline = Pipeline(stages=[rf_classifier, 
                                labelConverter])
    
    # Fit the entire pipeline
    pipeline_fit = pipeline.fit(transformed_train_data)
    train_predictions = pipeline_fit.transform(transformed_train_data)

    # Apply the fitted pipeline to the test data
    test_predictions = pipeline_fit.transform(transformed_test_data)
    
    train_predictions.write.parquet("s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/train_pred/", mode="overwrite")    
    test_predictions.write.parquet("s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/test_pred/", mode="overwrite")
    pipeline_fit.save("s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/model/")
    
    logger.info(f"all done...")
    
if __name__ == "__main__":
    main()
Overwriting ./code/ml_train_RF2.py
In [ ]:
%%time
bucket = "project-group34"

# Run the job. This script takes no command-line arguments.
spark_processor.run(
    submit_app="./code/ml_train_RF2.py",
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(bucket, output_prefix_logs),
    logs=False,
    configuration=configuration
)
INFO:sagemaker:Creating processing-job with name sm-spark-project-2023-11-28-23-33-31-511
CPU times: user 1.44 s, sys: 106 ms, total: 1.54 s
Wall time: 34min
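
To put the two random-forest runs side by side, the sketch below compares Spark ML's default RandomForestClassifier settings (numTrees=20, maxDepth=5, maxBins=32) with the values set in ml_train_RF2.py, and relates the two wall times reported above. The wall times cover the whole processing job, including cluster start-up and the prediction writes, so the ratio is only a rough indication of the extra training cost.

```python
# Spark ML defaults versus the values set in ml_train_RF2.py.
default_params = {"numTrees": 20, "maxDepth": 5, "maxBins": 32}
tuned_params = {"numTrees": 100, "maxDepth": 10, "maxBins": 64}

for name in default_params:
    ratio = tuned_params[name] / default_params[name]
    print(f"{name}: {default_params[name]} -> {tuned_params[name]} ({ratio:g}x)")

# A binary tree of depth d has at most 2**d leaves.
for label, params in (("default", default_params), ("tuned", tuned_params)):
    print(f"{label}: at most {2 ** params['maxDepth']} leaves per tree")

# Wall times reported by the two %%time cells, in seconds.
default_wall_s = 22 * 60 + 46
tuned_wall_s = 34 * 60
wall_ratio = tuned_wall_s / default_wall_s
print(f"wall-time ratio (tuned / default): {wall_ratio:.2f}")
```

The second configuration grows five times as many trees, each allowed to be far deeper, yet the job took only about one and a half times as long by this measure.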

EVALUATION¶

In [ ]:
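# Import pyspark
from pyspark.sql import SparkSession

# Build the Spark session. Spark ignores the bare "fs.s3a.aws.credentials.provider" key
# (it logs a warning); prefix it with "spark.hadoop." for the setting to take effect.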
spark = SparkSession.builder \
    .appName("Spark NLP")\
    .master("local[*]")\
    .config("spark.driver.memory","16G")\
    .config("spark.executor.memory", "12g")\
    .config("spark.executor.cores", "3")\
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.3,org.apache.hadoop:hadoop-aws:3.2.2")\
    .config(
            "fs.s3a.aws.credentials.provider",
            "com.amazonaws.auth.ContainerCredentialsProvider"
    )\
    .getOrCreate()

print(spark.version)
Warning: Ignoring non-Spark config property: fs.s3a.aws.credentials.provider
23/12/01 00:24:05 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
3.4.0
In [ ]:
import sagemaker
from pyspark.sql.functions import lower, regexp_replace, col, concat_ws, length
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from sparknlp.annotator import *
from sparknlp.base import *
import sparknlp
from sparknlp.pretrained import PretrainedPipeline
from sparknlp.base import Finisher, DocumentAssembler
In [ ]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, MultilayerPerceptronClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline, Model
In [ ]:
transformed_train_RF = spark.read.parquet("s3a://project-group34/project/submissions/RandomForest/default/train_pred/")
transformed_train_RF_withHyperparams = spark.read.parquet("s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/train_pred/")
23/12/01 00:26:18 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                
In [ ]:
transformed_test_RF = spark.read.parquet("s3a://project-group34/project/submissions/RandomForest/default/test_pred/")
transformed_test_RF_withHyperparams = spark.read.parquet("s3a://project-group34/project/submissions/RandomForest/numTrees=100_maxDepth=10_maxBins=64/test_pred/")
                                                                                
In [ ]:
evaluator = MulticlassClassificationEvaluator(labelCol="subreddit_ix", predictionCol="prediction", metricName="accuracy")
accuracy_RF_train = evaluator.evaluate(transformed_train_RF)
accuracy_RF_test = evaluator.evaluate(transformed_test_RF)
accuracy_RF_withHyperparams_train = evaluator.evaluate(transformed_train_RF_withHyperparams)
accuracy_RF_withHyperparams_test = evaluator.evaluate(transformed_test_RF_withHyperparams)
                                                                                
In [ ]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pandas as pd

def compute_metrics(dataset, label_col, prediction_col):
    """Return (accuracy, F1, weighted precision, weighted recall) for a predictions DataFrame."""
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col)
    # Override metricName per call so a single evaluator serves all four metrics;
    # "f1" is the support-weighted F1 score
    accuracy = evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
    f1 = evaluator.evaluate(dataset, {evaluator.metricName: "f1"})
    precision = evaluator.evaluate(dataset, {evaluator.metricName: "weightedPrecision"})
    recall = evaluator.evaluate(dataset, {evaluator.metricName: "weightedRecall"})
    return accuracy, f1, precision, recall

# Calculate metrics for the default model
accuracy_RF_train, f1_RF_train, precision_RF_train, recall_RF_train = compute_metrics(transformed_train_RF, "subreddit_ix", "prediction")
accuracy_RF_test, f1_RF_test, precision_RF_test, recall_RF_test = compute_metrics(transformed_test_RF, "subreddit_ix", "prediction")

# Calculate metrics for the model with tuned hyperparameters
accuracy_RF_hp_train, f1_RF_hp_train, precision_RF_hp_train, recall_RF_hp_train = compute_metrics(transformed_train_RF_withHyperparams, "subreddit_ix", "prediction")
accuracy_RF_hp_test, f1_RF_hp_test, precision_RF_hp_test, recall_RF_hp_test = compute_metrics(transformed_test_RF_withHyperparams, "subreddit_ix", "prediction")
In [ ]:
# Create a DataFrame with the metrics
metrics_data = {
    'Model': ['RandomForest Default', 'RandomForest Hyperparameters'],
    'Accuracy Train': [accuracy_RF_train, accuracy_RF_hp_train],
    'Accuracy Test': [accuracy_RF_test, accuracy_RF_hp_test],
    'F1 Score Train': [f1_RF_train, f1_RF_hp_train],
    'F1 Score Test': [f1_RF_test, f1_RF_hp_test],
    'Precision Train': [precision_RF_train, precision_RF_hp_train],
    'Precision Test': [precision_RF_test, precision_RF_hp_test],
    'Recall Train': [recall_RF_train, recall_RF_hp_train],
    'Recall Test': [recall_RF_test, recall_RF_hp_test]
}

metrics_df = pd.DataFrame(metrics_data)
In [ ]:
metrics_df
Out[ ]:
                          Model  Accuracy Train  Accuracy Test  F1 Score Train  F1 Score Test  Precision Train  Precision Test  Recall Train  Recall Test
0          RandomForest Default        0.676122       0.675532        0.642950       0.642239         0.760997        0.760339      0.676122     0.675532
1  RandomForest Hyperparameters        0.760637       0.759841        0.745298       0.744455         0.829303        0.828266      0.760637     0.759841
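In the table above, the Recall columns match the Accuracy columns exactly. That is expected for `weightedRecall`: each class's recall is weighted by that class's share of the rows, so the class counts cancel and the sum reduces to correct predictions over all predictions. The sketch below checks this in plain Python on a hypothetical 2x2 confusion matrix. The counts are made up for illustration and are not this notebook's results.

```python
# Hypothetical confusion matrix (rows = true class, columns = predicted class).
# Illustrative counts only; NOT taken from the models in this notebook.
cm = [
    [80, 20],  # true class 0: 80 correct, 20 predicted as class 1
    [10, 40],  # true class 1: 10 predicted as class 0, 40 correct
]

n_classes = len(cm)
total = sum(sum(row) for row in cm)

# Accuracy: diagonal (correct predictions) over all predictions.
accuracy = sum(cm[i][i] for i in range(n_classes)) / total

# Per-class recall: correct predictions for a class over its true row count.
support = [sum(row) for row in cm]
recall = [cm[i][i] / support[i] for i in range(n_classes)]

# Weighted recall: per-class recall weighted by the class's share of the data.
weighted_recall = sum(recall[i] * support[i] / total for i in range(n_classes))

# Weighted precision divides by column sums (predicted counts) instead,
# so nothing cancels and it generally differs from accuracy.
predicted = [sum(cm[r][c] for r in range(n_classes)) for c in range(n_classes)]
precision = [cm[i][i] / predicted[i] for i in range(n_classes)]
weighted_precision = sum(precision[i] * support[i] / total for i in range(n_classes))

print(f"accuracy={accuracy:.4f} weighted_recall={weighted_recall:.4f} "
      f"weighted_precision={weighted_precision:.4f}")
```

The recall columns therefore add no information beyond accuracy here. Per-class recall would show how each subreddit fares individually.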
In [ ]:
from sklearn.metrics import confusion_matrix
In [ ]:
# Collect labels and predictions in a single pass so each pair stays row-aligned;
# two separate collect() calls are not guaranteed to return rows in the same order
rows_RF_test = transformed_test_RF.select("subreddit_ix", "prediction").collect()
y_orig_RF_test = [row['subreddit_ix'] for row in rows_RF_test]
y_pred_RF_test = [row['prediction'] for row in rows_RF_test]

rows_RF_hp_test = transformed_test_RF_withHyperparams.select("subreddit_ix", "prediction").collect()
y_orig_RF_hp_test = [row['subreddit_ix'] for row in rows_RF_hp_test]
y_pred_RF_hp_test = [row['prediction'] for row in rows_RF_hp_test]
In [ ]:
cm_RF = confusion_matrix(y_orig_RF_test, y_pred_RF_test)
cm_RF_hp = confusion_matrix(y_orig_RF_hp_test, y_pred_RF_hp_test)
In [ ]:
import matplotlib.pyplot as plt
import seaborn as sns
In [ ]:
# Plot using heatmap

plt.figure(figsize=(10, 7))
sns.heatmap(cm_RF, annot=True, fmt="d", cmap=sns.color_palette("YlOrBr", as_cmap=True), yticklabels=["Anime", "Movie"], xticklabels=["Anime", "Movie"])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix for Random Forest Model with default Hyperparameters')
plt.savefig("../../data/plots/Heatmap1_CM_Classification_Model1.png", dpi=300)
plt.show()
In [ ]:
# Plot using heatmap

plt.figure(figsize=(10, 7))
sns.heatmap(cm_RF_hp, annot=True, fmt="d", cmap=sns.color_palette("YlOrBr", as_cmap=True), yticklabels=["Anime", "Movie"], xticklabels=["Anime", "Movie"])
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix for Random Forest Model with Hyperparameters')
plt.savefig("../../data/plots/Heatmap2_CM_Classification_Model1.png", dpi=300)
plt.show()
In [ ]:
transformed_test_RF.printSchema()
root
 |-- subreddit: string (nullable = true)
 |-- score: long (nullable = true)
 |-- num_comments: long (nullable = true)
 |-- over_18: string (nullable = true)
 |-- is_self: string (nullable = true)
 |-- is_video: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- post_length: integer (nullable = true)
 |-- hour_of_day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- day_of_month: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- has_media: string (nullable = true)
 |-- body: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- subreddit_ix: double (nullable = true)
 |-- over_18_ix: double (nullable = true)
 |-- is_self_ix: double (nullable = true)
 |-- is_video_ix: double (nullable = true)
 |-- has_media_ix: double (nullable = true)
 |-- over_18_vec: vector (nullable = true)
 |-- is_self_vec: vector (nullable = true)
 |-- is_video_vec: vector (nullable = true)
 |-- has_media_vec: vector (nullable = true)
 |-- combined_features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)
 |-- predictedSubreddit: string (nullable = true)

In [ ]:
from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col

# Probability of the positive class (index 1 of the probability vector);
# vector_to_array avoids the serialization overhead of a Python UDF
positive_probability = vector_to_array(col("probability"))[1]

# Add a new column with the probability for the positive class
roc_data_RF = transformed_test_RF.withColumn("positive_probability", positive_probability)
roc_data_RF_hp = transformed_test_RF_withHyperparams.withColumn("positive_probability", positive_probability)

# Collect the data
roc_data_RF = roc_data_RF.select("positive_probability", "subreddit_ix").toPandas()
roc_data_RF_hp = roc_data_RF_hp.select("positive_probability", "subreddit_ix").toPandas()

# Probabilities and actual labels
y_probs_RF = roc_data_RF['positive_probability']
y_orig_RF = roc_data_RF['subreddit_ix']
y_probs_RF_hp = roc_data_RF_hp['positive_probability']
y_orig_RF_hp = roc_data_RF_hp['subreddit_ix']
                                                                                
In [ ]:
from sklearn.metrics import roc_curve, auc

# Calculate the ROC curve points
fpr, tpr, thresholds = roc_curve(y_orig_RF, y_probs_RF)

# Calculate the AUC (Area Under the Curve)
roc_auc_RF = auc(fpr, tpr)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_RF:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve for RF Model with default hyperparameters')
plt.legend(loc="lower right")
plt.savefig("../../data/plots/ROC1_Classification_Model1.png", dpi=300)
plt.show()
In [ ]:
from sklearn.metrics import roc_curve, auc

# Calculate the ROC curve points
fpr, tpr, thresholds = roc_curve(y_orig_RF_hp, y_probs_RF_hp)

# Calculate the AUC (Area under the Curve)
roc_auc_RF_hp = auc(fpr, tpr)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_RF_hp:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve for RF Model with hyperparameters')
plt.legend(loc="lower right")
plt.savefig("../../data/plots/ROC2_Classification_Model1.png", dpi=300)
plt.show()
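To make the `roc_curve` and `auc` calls above less opaque, here is a minimal pure-Python sketch of the same idea: lower the decision threshold one distinct score at a time, record the false and true positive rates at each step, then integrate with the trapezoid rule. The helper names `roc_points` and `trapezoid_auc` and the four-row toy data are hypothetical, and the sketch assumes binary 0/1 labels with both classes present. It is not scikit-learn's implementation, which additionally returns thresholds and can drop collinear points.

```python
def roc_points(labels, scores):
    """Return (fpr, tpr) lists from sweeping the threshold from high to low."""
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    # Sort by descending score so each step admits the next-highest score group.
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Consume every example tied at this score before emitting a point.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / neg)
        tpr.append(tp / pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear ROC curve via the trapezoid rule."""
    return sum(
        (fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2
        for k in range(1, len(fpr))
    )


# Toy data (hypothetical): two negatives and two positives.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

fpr, tpr = roc_points(labels, scores)
auc_value = trapezoid_auc(fpr, tpr)
print(fpr, tpr, auc_value)
```

An AUC of 0.5 corresponds to the dashed diagonal in the plots (random ranking), and 1.0 means every positive is scored above every negative.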
In [ ]: