Skip to content
Permalink
07c3e15e7c
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
45 lines (36 sloc) 1.83 KB
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm
import torch
import numpy as np
def _get_sentiment(sentiment_model, tokenizer, device, dialogs):
    """Score the sentiment polarity of every utterance in every dialog.

    Each dialog is tokenized and classified as a single batch. The class
    probabilities are collapsed into one scalar per utterance by taking
    their expectation under the class weights ``[-1, 0, 1]``.

    Args:
        sentiment_model: Callable invoked as ``sentiment_model(**inputs)``
            whose result exposes ``.logits``. The last axis of the logits
            must have exactly three classes to match the class weights.
            NOTE(review): the weights presume the class order
            (negative, neutral, positive) -- confirm for the checkpoint
            in use.
        tokenizer: Callable that accepts a list of strings plus the
            ``padding``/``truncation``/``return_tensors``/``max_length``
            keywords and returns a batch supporting ``.to(device)`` and
            ``**`` unpacking.
        device: Device the tokenized inputs and class weights are moved to.
        dialogs: Iterable of dialogs. Each dialog is an iterable of
            indexable items whose element ``[1]`` is the utterance text.

    Returns:
        A list with one entry per dialog; each entry is a list holding one
        float score in ``[-1, 1]`` per utterance of that dialog. A dialog
        without utterances yields an empty list.
    """
    # Loop-invariant: build the weights once and keep them on the device.
    class_weights = torch.tensor([-1., 0., 1.]).to(device)

    scores = []
    for dialog in tqdm(dialogs, desc="Scoring sentiment"):
        utterances = [utterance[1] for utterance in dialog]
        if not utterances:
            # Nothing to tokenize; keep the output aligned with the input.
            scores.append([])
            continue

        # Inputs longer than 512 tokens are truncated; shorter ones are
        # padded to the longest utterance of this dialog.
        inputs = tokenizer(utterances,
                           padding=True,
                           truncation=True,
                           return_tensors="pt",
                           max_length=512)
        inputs = inputs.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # Collapse the three class probabilities into one polarity
            # score in [-1, 1]: P(last class) - P(first class).
            polarity_scores = torch.matmul(probs, class_weights)

        scores.append(polarity_scores.to("cpu").numpy().tolist())
    return scores
def get_sentiment(sentiment_modelpath, dialogs):
    """Load a sentiment classifier and score every utterance in ``dialogs``.

    Args:
        sentiment_modelpath: Identifier or path handed to the
            ``from_pretrained`` loaders of both the tokenizer and the
            sequence-classification model.
        dialogs: Iterable of dialogs, forwarded unchanged to
            ``_get_sentiment``.

    Returns:
        Whatever ``_get_sentiment`` returns for the loaded model.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    tok = AutoTokenizer.from_pretrained(sentiment_modelpath)
    classifier = AutoModelForSequenceClassification.from_pretrained(sentiment_modelpath)
    classifier.to(target)

    return _get_sentiment(classifier, tok, target, dialogs)
def get_sentiment2(text, batch_size=32, sentiment_modelpath="cardiffnlp/twitter-roberta-base-sentiment"):
    """Score a flat sequence of texts and return one score per text.

    The texts are cut into consecutive batches of at most ``batch_size``
    items. Each text is wrapped as ``(None, text)`` so that a batch has the
    per-dialog layout ``get_sentiment`` receives, where element ``[1]``
    carries the text.

    Args:
        text: Sized, sliceable sequence of strings to score.
        batch_size: Maximum number of texts scored per model call.
        sentiment_modelpath: Model identifier or path forwarded to
            ``get_sentiment``.

    Returns:
        A 1-D ``numpy.ndarray`` with one score per input text, in input
        order. Empty input yields an empty float64 array and no model is
        loaded.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # Fail before loading a model; a zero step would make range() raise
        # and a negative one would produce no batches at all.
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    # len() rather than truthiness so that array-like inputs also work.
    if len(text) == 0:
        # np.concatenate rejects an empty sequence, so short-circuit here.
        return np.empty(0, dtype=np.float64)

    batches = [
        [(None, item) for item in text[start:start + batch_size]]
        for start in range(0, len(text), batch_size)
    ]
    scores = get_sentiment(sentiment_modelpath, batches)
    return np.concatenate(scores)