Scikit-learn Integration
In the introductory tutorial, you learned how to use the components and extractors of TextDescriptives and saw how to use them for exploratory data analysis.
In this tutorial, we walk through how to use TextDescriptives in a scikit-learn pipeline for tasks such as text classification.
Setup
We’ll use the same dataset as in the introductory tutorial, i.e. the SMS Spam Collection Data Set. The dataset contains 5572 SMS messages categorized as ham or spam.
Let’s load the required packages (installing them if necessary) and the dataset.
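```python
try:
    import textdescriptives
    import sklearn
except ImportError:
    # Jupyter shell magic; in a terminal, run the pip command without the "!"
    !pip install "textdescriptives[tutorials]"
```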
To use TextDescriptives in a scikit-learn pipeline, instantiate TextDescriptivesFeaturizer with the same arguments you would pass to extract_metrics, and wrap it in a scikit-learn Pipeline.
Let’s try training a classifier on the SMS data, using the descriptive_stats feature set as an example.
```python
from textdescriptives.utils import load_sms_data

df = load_sms_data()
df.head()
```
| | label | message |
|---|---|---|
| 0 | ham | Go until jurong point, crazy.. Available only ... |
| 1 | ham | Ok lar... Joking wif u oni... |
| 2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | ham | U dun say so early hor... U c already then say... |
| 4 | ham | Nah I don't think he goes to usf, he lives aro... |
Alright, the “message” column contains the text we want to extract metrics from, and the “label” column contains the class (ham or spam). Now, let’s instantiate the featurizer.
```python
from textdescriptives.integrations.sklearn_featurizer import TextDescriptivesFeaturizer

# instantiate the featurizer with the same options as you would pass
# to textdescriptives.extract_metrics
descriptive_stats_extractor = TextDescriptivesFeaturizer(
    lang="en", metrics=["descriptive_stats"]
)
```
Time to build the pipeline. Wrap the featurizer in a ColumnTransformer so that it only operates on the “message” column, which holds the text in this example.
Because some metrics can come out missing (NaN) after extraction, we use a SimpleImputer to replace missing values with the median of each feature.
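To make the imputation step concrete, here is a small sketch on a made-up two-column matrix (the numbers are illustrative, not TextDescriptives output). SimpleImputer(strategy="median") learns one median per column during fit, ignoring NaN, and reuses those stored values at transform time. Inside the pipeline, fit only sees X_train, so the test split is filled with training-set medians. The median is also less sensitive than the mean to extreme values, such as an unusually long message.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Toy metric matrix: NaN marks a metric that could not be computed for a text.
X_train_toy = np.array(
    [
        [10.0, np.nan],
        [20.0, 2.0],
        [np.nan, 4.0],
        [40.0, 9.0],
    ]
)

imputer = SimpleImputer(strategy="median")
X_filled = imputer.fit_transform(X_train_toy)

# Medians are computed per column, ignoring NaN: 20.0 and 4.0 here.
print(imputer.statistics_)
print(X_filled)

# At transform time, the medians learned during fit are reused for new rows.
print(imputer.transform(np.array([[np.nan, np.nan]])))
```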
Finally, we use a RandomForestClassifier as the classifier, split the data into training and test sets, and train and evaluate the model.
```python
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn import set_config

# This tells sklearn to use pandas dataframes as output (scikit-learn >= 1.2),
# which makes it easier to access the feature names
set_config(transform_output="pandas")

pipe = Pipeline(
    [
        (
            "featurizer",
            ColumnTransformer(
                [("text_processing", descriptive_stats_extractor, "message")],
                # removes the `text_processing__` prefix from feature names
                verbose_feature_names_out=False,
            ),
        ),
        ("imputer", SimpleImputer(strategy="median")),
        ("classifier", RandomForestClassifier()),
    ]
)

# split the data into train and test
X_train, X_test, y_train, y_test = train_test_split(
    df.drop("label", axis=1),
    df["label"],
    test_size=0.2,
    random_state=42,
)

# fit the pipeline and evaluate
pipe.fit(X_train, y_train)
print("Test accuracy:", pipe.score(X_test, y_test))
```
Test accuracy: 0.9452914798206278
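For a classifier pipeline, pipe.score returns mean accuracy: the fraction of test messages whose predicted label matches the true one. As a quick sanity check on the number above, assuming load_sms_data returns all 5572 messages: with a float test_size, train_test_split rounds the test-set size up, so the test split holds ceil(0.2 × 5572) = 1115 messages, and the reported accuracy corresponds to a whole number of correct predictions.

```python
import math

n_samples = 5572  # size of the SMS Spam Collection, as stated above (assumed fully loaded)
test_size = 0.2

# train_test_split rounds the test-set size up for a float test_size.
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

reported_accuracy = 0.9452914798206278  # printed by pipe.score above
n_correct = round(reported_accuracy * n_test)

print(n_train, n_test)                # sizes of the two splits
print(n_correct, n_test - n_correct)  # correct vs. misclassified messages
```

The result, 1054 correct out of 1115, divides back to exactly the printed score, which supports the assumption about the split sizes.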
Nice! TextDescriptivesFeaturizer implements the get_feature_names_out method, so feature names are passed along through the pipeline. Together with the pandas output configured via set_config, the classifier is fit on named columns, which gives us informative names for, e.g., feature importances.
```python
import pandas as pd

classifier = pipe.named_steps["classifier"]

# extract feature importances, paired with the feature names
# that were propagated through the pipeline
feature_importance_mapping = list(
    zip(classifier.feature_names_in_, classifier.feature_importances_)
)
print("Feature importances:")

# sort by importance
df_importances = pd.DataFrame(
    feature_importance_mapping, columns=["Feature", "Importance"]
).sort_values(by="Importance", ascending=False)
df_importances
```
Feature importances:
| | Feature | Importance |
|---|---|---|
| 12 | n_characters | 0.184745 |
| 0 | token_length_mean | 0.149782 |
| 2 | token_length_std | 0.119453 |
| 10 | n_unique_tokens | 0.099166 |
| 9 | n_tokens | 0.095011 |
| 5 | sentence_length_std | 0.058420 |
| 11 | proportion_unique_tokens | 0.051912 |
| 6 | syllables_per_token_mean | 0.050768 |
| 8 | syllables_per_token_std | 0.047507 |
| 4 | sentence_length_median | 0.038493 |
| 13 | n_sentences | 0.037130 |
| 3 | sentence_length_mean | 0.036220 |
| 1 | token_length_median | 0.028888 |
| 7 | syllables_per_token_median | 0.002505 |
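Impurity-based importances from a RandomForestClassifier are normalized to sum to 1, so the table can be read as shares of the total. The sketch below re-enters the values printed above (copied from the table, not recomputed) and accumulates them to show how concentrated the signal is; the top-5 and 80% cutoffs are arbitrary choices for illustration.

```python
import pandas as pd

# Importances as printed in the table above (already sorted, descending).
importances = pd.Series(
    {
        "n_characters": 0.184745,
        "token_length_mean": 0.149782,
        "token_length_std": 0.119453,
        "n_unique_tokens": 0.099166,
        "n_tokens": 0.095011,
        "sentence_length_std": 0.058420,
        "proportion_unique_tokens": 0.051912,
        "syllables_per_token_mean": 0.050768,
        "syllables_per_token_std": 0.047507,
        "sentence_length_median": 0.038493,
        "n_sentences": 0.037130,
        "sentence_length_mean": 0.036220,
        "token_length_median": 0.028888,
        "syllables_per_token_median": 0.002505,
    }
)

cumulative = importances.cumsum()
print(round(importances.sum(), 4))   # total share across all features
print(round(cumulative.iloc[4], 4))  # share of the five most important features
print(cumulative[cumulative >= 0.8].index[0])  # feature at which 80% is reached
```

Read this way, the five length- and count-based features at the top account for roughly 65% of the total importance, and eight of the fourteen features are needed to pass 80%.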