Skip to content

Commit 257fd57

Browse files
authored
Added custom stop token id support (#14344)
1 parent e120e61 commit 257fd57

File tree

14 files changed

+212
-47
lines changed

14 files changed

+212
-47
lines changed

src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ private[johnsnowlabs] class LLAMA2(
7979
*/
8080
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
8181
sentences.map(s => {
82-
val sentWithTask = s.result
83-
spp.getSppModel.encodeAsIds(sentWithTask)
82+
val sentWithTask = "_" + s.result
83+
Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
8484
})
8585
}
8686

@@ -97,7 +97,8 @@ private[johnsnowlabs] class LLAMA2(
9797
randomSeed: Option[Long],
9898
ignoreTokenIds: Array[Int] = Array(),
9999
beamSize: Int,
100-
maxInputLength: Int): Array[Array[Int]] = {
100+
maxInputLength: Int,
101+
stopTokenIds: Array[Int]): Array[Array[Int]] = {
101102
val ignoreTokenIdsInt = ignoreTokenIds
102103
val expandedDecoderInputsVals = batch
103104
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
@@ -165,7 +166,8 @@ private[johnsnowlabs] class LLAMA2(
165166
ignoreTokenIdsInt,
166167
session,
167168
applySoftmax = true,
168-
ovInferRequest = ovInferRequest)
169+
ovInferRequest = ovInferRequest,
170+
stopTokenIds = stopTokenIds)
169171

170172
modelOutputs
171173
}
@@ -184,7 +186,8 @@ private[johnsnowlabs] class LLAMA2(
184186
randomSeed: Option[Long] = None,
185187
ignoreTokenIds: Array[Int] = Array(),
186188
beamSize: Int,
187-
maxInputLength: Int): Seq[Annotation] = {
189+
maxInputLength: Int,
190+
stopTokenIds: Array[Int]): Seq[Annotation] = {
188191

189192
val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
190193
val batchSP = encode(batch)
@@ -201,7 +204,8 @@ private[johnsnowlabs] class LLAMA2(
201204
randomSeed,
202205
ignoreTokenIds,
203206
beamSize,
204-
maxInputLength)
207+
maxInputLength,
208+
stopTokenIds)
205209

206210
decode(spIds)
207211

src/main/scala/com/johnsnowlabs/ml/ai/Mistral.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ private[johnsnowlabs] class Mistral(
7878
*/
7979
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
8080
sentences.map(s => {
81-
val sentWithTask = s.result
82-
spp.getSppModel.encodeAsIds(sentWithTask)
81+
val sentWithTask = "_" + s.result
82+
Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
8383
})
8484
}
8585

@@ -96,7 +96,8 @@ private[johnsnowlabs] class Mistral(
9696
randomSeed: Option[Long],
9797
ignoreTokenIds: Array[Int] = Array(),
9898
beamSize: Int,
99-
maxInputLength: Int): Array[Array[Int]] = {
99+
maxInputLength: Int,
100+
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
100101
val ignoreTokenIdsInt = ignoreTokenIds
101102
val expandedDecoderInputsVals = batch
102103
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
@@ -162,8 +163,9 @@ private[johnsnowlabs] class Mistral(
162163
randomSeed,
163164
ignoreTokenIdsInt,
164165
session,
165-
applySoftmax = false,
166-
ovInferRequest = ovInferRequest)
166+
applySoftmax = true,
167+
ovInferRequest = ovInferRequest,
168+
stopTokenIds = stopTokenIds)
167169

168170
// decoderOutputs
169171
modelOutputs
@@ -183,7 +185,8 @@ private[johnsnowlabs] class Mistral(
183185
randomSeed: Option[Long] = None,
184186
ignoreTokenIds: Array[Int] = Array(),
185187
beamSize: Int,
186-
maxInputLength: Int): Seq[Annotation] = {
188+
maxInputLength: Int,
189+
stopTokenIds: Array[Int]): Seq[Annotation] = {
187190

188191
val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
189192
val batchSP = encode(batch)
@@ -200,7 +203,8 @@ private[johnsnowlabs] class Mistral(
200203
randomSeed,
201204
ignoreTokenIds,
202205
beamSize,
203-
maxInputLength)
206+
maxInputLength,
207+
stopTokenIds)
204208

205209
decode(spIds)
206210

src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ private[johnsnowlabs] class Phi2(
103103
randomSeed: Option[Long],
104104
ignoreTokenIds: Array[Int] = Array(),
105105
beamSize: Int,
106-
maxInputLength: Int): Array[Array[Int]] = {
106+
maxInputLength: Int,
107+
stopTokenIds: Array[Int]): Array[Array[Int]] = {
107108
val ignoreTokenIdsInt = ignoreTokenIds
108109
val expandedDecoderInputsVals = batch
109110
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
@@ -169,7 +170,8 @@ private[johnsnowlabs] class Phi2(
169170
ignoreTokenIdsInt,
170171
session,
171172
applySoftmax = false,
172-
ovInferRequest = ovInferRequest)
173+
ovInferRequest = ovInferRequest,
174+
stopTokenIds = stopTokenIds)
173175

174176
// decoderOutputs
175177
modelOutputs
@@ -189,7 +191,8 @@ private[johnsnowlabs] class Phi2(
189191
randomSeed: Option[Long] = None,
190192
ignoreTokenIds: Array[Int] = Array(),
191193
beamSize: Int,
192-
maxInputLength: Int): Seq[Annotation] = {
194+
maxInputLength: Int,
195+
stopTokenIds: Array[Int]): Seq[Annotation] = {
193196

194197
val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
195198
val batchSP = encode(batch)
@@ -206,7 +209,8 @@ private[johnsnowlabs] class Phi2(
206209
randomSeed,
207210
ignoreTokenIds,
208211
beamSize,
209-
maxInputLength)
212+
maxInputLength,
213+
stopTokenIds)
210214

211215
decode(spIds)
212216

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ trait Generate {
104104
ignoreTokenIds: Array[Int] = Array(),
105105
session: Either[Session, (OrtEnvironment, OrtSession)],
106106
applySoftmax: Boolean = true,
107-
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
107+
ovInferRequest: Option[InferRequest] = None,
108+
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
108109

109110
// TODO: Add support for ignoreTokenIds
110111

@@ -117,8 +118,8 @@ trait Generate {
117118
noRepeatNgramSize = noRepeatNgramSize,
118119
vocabSize = vocabSize))
119120

120-
logitProcessorList.addProcess(
121-
new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))
121+
// logitProcessorList.addProcess(
122+
// new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))
122123

123124
logitProcessorList.addProcess(new TemperatureLogitWarper(temperature))
124125

@@ -148,7 +149,8 @@ trait Generate {
148149
randomSeed,
149150
session,
150151
applySoftmax,
151-
ovInferRequest)
152+
ovInferRequest,
153+
stopTokenIds)
152154
}
153155

154156
/** Beam Search for text generation
@@ -193,7 +195,8 @@ trait Generate {
193195
randomSeed: Option[Long],
194196
session: Either[Session, (OrtEnvironment, OrtSession)],
195197
applySoftmax: Boolean,
196-
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
198+
ovInferRequest: Option[InferRequest] = None,
199+
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
197200
val inputIds = inputIdsVal
198201
val batchSize = beamScorer.getBeamHypothesesSeq.length
199202
val numBeams = beamScorer.getNumBeams
@@ -227,21 +230,22 @@ trait Generate {
227230
// Optionally Apply log softmax to model outputs
228231
var nextTokenScores =
229232
if (applySoftmax) nextTokenLogits.map(logSoftmax) else nextTokenLogits
230-
231233
// Process the logits by defined logit processors
232234
val nextTokenScoresProcessed =
233235
logitProcessor.process(expandedInputs, nextTokenScores, currentLength)
234236

237+
// Process the logits by defined logit warpers
238+
if (doSample) {
239+
nextTokenScores =
240+
logitProcessor.warp(expandedInputs, nextTokenScoresProcessed, currentLength)
241+
}
235242
// Add previous beam scores to the output
236-
nextTokenScores = nextTokenScoresProcessed.zipWithIndex.map { case (x, ind1) =>
243+
nextTokenScores = nextTokenScores.zipWithIndex.map { case (x, ind1) =>
237244
x.zipWithIndex.map { case (y, _) =>
238245
y + beamScores(ind1)
239246
}
240247
}
241-
// Process the logits by defined logit warpers
242-
if (doSample) {
243-
nextTokenScores = logitProcessor.warp(expandedInputs, nextTokenScores, currentLength)
244-
}
248+
245249
// Reshape next token score to (batchSize, vocabSize * numBeams)
246250
val vocabSize = nextTokenScores.head.length
247251
val reshapedNextTokenScores =
@@ -290,7 +294,8 @@ trait Generate {
290294
padTokenId,
291295
eosTokenId,
292296
beamIndices,
293-
currentLength)
297+
currentLength,
298+
stopTokenIds)
294299
val newBeamScores = beamOutputs._1.flatMap(_.toList)
295300
val beamNextTokens = beamOutputs._2.flatMap(_.toList)
296301
val beamIdx = beamOutputs._3.flatMap(_.toList)

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ class TopKLogitWarper(
4343
}
4444

4545
private def getTopKIndices(logits: Array[Float], k: Int): Array[Int] = {
46-
logits.indices.sortBy(logits(_)).reverse.take(k).toArray
46+
// ignore Float.NegativeInfinity values
47+
val topKIndices = new ArrayBuffer[Int]()
48+
val sortedLogits = logits.zipWithIndex.filter(_._1 != filterValue).sortBy(-_._1)
49+
for ((_, i) <- sortedLogits.take(k)) {
50+
topKIndices += i
51+
}
52+
topKIndices.toArray
4753
}
4854

4955
private def maskNotTopKValues(logits: Array[Float], topKIndices: Array[Int]): Array[Float] = {

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,40 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit
2424
val logitsUpd = scores.map(_.clone()) // Deep copy of the scores
2525

2626
if (p < 1.0) {
27-
val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values
28-
val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length)
29-
val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold
27+
val scoresFiltered = scores // Filter out infinite values
28+
val scoresSoftmaxed = scoresFiltered.map(softmax) // Softmax the scores
3029

31-
for ((logits, i) <- scores.zipWithIndex) {
32-
val topPIndices = getTopPIndices(logits, topPThreshold)
33-
val maskedValues = maskNotTopPValues(logits, topPIndices)
30+
for ((logits, i) <- scoresSoftmaxed.zipWithIndex) {
31+
val topPIndices = getTopPIndices(logits, p)
32+
// Mask the values that are not in the top-p
33+
val maskedValues = maskNotTopPValues(logitsUpd(i), topPIndices)
3434
logitsUpd(i) = maskedValues
3535
}
3636
}
3737

3838
logitsUpd
3939
}
4040

41-
private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = {
42-
logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)
41+
private def getTopPIndices(logits: Array[Float], p: Double): Array[Int] = {
42+
// sort the logits in descending order
43+
var sortedLogits = logits.zipWithIndex.sortBy(-_._1)
44+
45+
// filter out the negative infinity values
46+
sortedLogits = sortedLogits.filter(_._1 > 0.0)
47+
48+
// cumulative sum of the probabilities
49+
val cumSum = sortedLogits.map(_._1).scanLeft(0.0)(_ + _)
50+
51+
// find the index of the first element whose cumulative probability reaches p
52+
val lastIdx = cumSum.indexWhere(_ >= p)
53+
// if fewer than minTokensToKeep tokens reach p, fall back to keeping the top ceil(p * vocabSize) tokens
54+
55+
if (lastIdx < minTokensToKeep) {
56+
sortedLogits.take(math.ceil(p * logits.length).toInt).map(_._2)
57+
} else {
58+
sortedLogits.take(lastIdx).map(_._2)
59+
}
60+
4361
}
4462

4563
private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = {

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamScorer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ abstract class BeamScorer() {
2626
padTokenId: Int,
2727
eosTokenId: Int,
2828
beamIndices: Seq[Array[Int]],
29-
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])
29+
currentLength: Int,
30+
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])
3031

3132
def finalize(
3233
inputIds: Seq[Array[Int]],
@@ -40,4 +41,5 @@ abstract class BeamScorer() {
4041
def getBeamHypothesesSeq: Seq[BeamHypotheses]
4142
def getNumBeams: Int
4243
def isDone: Boolean
44+
def getDone: Array[Boolean]
4345
}

src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamSearchScorer.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class BeamSearchScorer(
4343
override def getNumBeams: Int = numBeams
4444
private val done: Array[Boolean] = Array.fill(batchSize)(false)
4545

46+
override def getDone: Array[Boolean] = done
47+
4648
override def process(
4749
inputIds: Seq[Array[Int]],
4850
nextScores: Seq[Array[Float]],
@@ -51,7 +53,8 @@ class BeamSearchScorer(
5153
padTokenId: Int,
5254
eosTokenId: Int,
5355
beamIndices: Seq[Array[Int]],
54-
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
56+
currentLength: Int,
57+
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
5558
// val currentLength = inputIds.length
5659
val batchSize = this.beamHypothesesSeq.length
5760
val nextBeamScores = Array.ofDim[Float](batchSize, this.beamSize)
@@ -75,7 +78,8 @@ class BeamSearchScorer(
7578
val nextIndex = nextIndices(batchIdx)(beamTokenRank)
7679
val batchBeamIdx = batchIdx * this.beamSize + nextIndex
7780

78-
if (eosTokenId == nextToken) {
81+
// either eos token or stop tokens are found
82+
if (eosTokenId == nextToken || stopTokenIds.contains(nextToken)) {
7983
if (beamTokenRank >= this.beamSize) {
8084
break
8185
}

src/main/scala/com/johnsnowlabs/nlp/HasGeneratorProperties.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,4 +222,19 @@ trait HasGeneratorProperties {
222222

223223
/** @group getParam */
224224
def getNReturnSequences: Int = $(nReturnSequences)
225+
226+
/** Stop tokens to terminate the generation
227+
*
228+
* @group param
229+
*/
230+
var stopTokenIds =
231+
new IntArrayParam(this, "stopTokens", "Stop tokens to terminate the generation")
232+
233+
/** @group setParam */
234+
def setStopTokenIds(value: Array[Int]): this.type = {
235+
set(stopTokenIds, value)
236+
}
237+
238+
/** @group getParam */
239+
def getStopTokenIds: Array[Int] = $(stopTokenIds)
225240
}

src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ class LLAMA2Transformer(override val uid: String)
235235
ignoreTokenIds -> Array(),
236236
batchSize -> 1,
237237
beamSize -> 1,
238-
maxInputLength -> 4096)
238+
maxInputLength -> 4096,
239+
stopTokenIds -> Array())
239240

240241
/** takes a document and annotations and produces new annotations of this annotator's annotation
241242
* type
@@ -269,7 +270,8 @@ class LLAMA2Transformer(override val uid: String)
269270
randomSeed = this.randomSeed,
270271
ignoreTokenIds = $(ignoreTokenIds),
271272
beamSize = $(beamSize),
272-
maxInputLength = $(maxInputLength))
273+
maxInputLength = $(maxInputLength),
274+
stopTokenIds = $(stopTokenIds))
273275
} else {
274276
Seq()
275277
}

0 commit comments

Comments
 (0)