Model prints wrong responses in Java (ai.onnxruntime)

#1
by kgrabko - opened

Guys, what is wrong with my extract:

You: Were is city Rome ?
We is city is ares are
You: Hi
Hi
You: cool
cool
You: How are you doing today ?
???????
???
You: tell me about you self as about big LLM
about and and and about about you asCOMRA
You: Where is Rome city ?
?????
You: what coutry ?
is? ?
You: is this Capital ?
is ist ist? ist

===============

// Imports used: ai.onnxruntime.*, ai.djl.sentencepiece.SpProcessor,
// java.nio.LongBuffer, java.util.Arrays, java.util.HashMap, java.util.Map
public static String generate(String prompt, OrtEnvironment env, OrtSession encoderSession,
        OrtSession decoderSession, SpProcessor sp) throws OrtException {
    // Tokenize the input text with DJL SentencePiece
    int[] inputIdsInt = sp.encode(prompt);
    long[] inputIds = Arrays.stream(inputIdsInt).mapToLong(i -> i).toArray();
    long[] attentionMask = new long[inputIds.length];
    Arrays.fill(attentionMask, 1L);

    // Run the encoder on the prompt
    Map<String, OnnxTensor> encoderInputs = new HashMap<>();
    long[][] inputIds2D = new long[][] { inputIds };
    long[][] attentionMask2D = new long[][] { attentionMask };
    encoderInputs.put("input_ids", OnnxTensor.createTensor(env, inputIds2D));
    encoderInputs.put("attention_mask", OnnxTensor.createTensor(env, attentionMask2D));
    OrtSession.Result encoderResult = encoderSession.run(encoderInputs);
    float[][][] encoderHiddenState =
            (float[][][]) ((OnnxTensor) encoderResult.get("last_hidden_state").get()).getValue();

    // Single decoder pass, fed with the encoder's input_ids
    Map<String, OnnxTensor> decoderInputs = new HashMap<>();
    LongBuffer buffer = LongBuffer.wrap(inputIds); // note: never used
    decoderInputs.put("input_ids", OnnxTensor.createTensor(env, inputIds2D));
    decoderInputs.put("encoder_attention_mask", OnnxTensor.createTensor(env, attentionMask2D));
    decoderInputs.put("encoder_hidden_states", OnnxTensor.createTensor(env, encoderHiddenState));
    OrtSession.Result decoderResult = decoderSession.run(decoderInputs);
    OnnxTensor decoderOutput = (OnnxTensor) decoderResult.get("logits").get();
    long[] shape = decoderOutput.getInfo().getShape();
    float[][][] logits = (float[][][]) decoderOutput.getValue();
    int batchSize = (int) shape[0];
    int sequenceLength = (int) shape[1];
    int vocabSize = (int) shape[2];

    // Greedy argmax over every position of that single pass
    int[] generatedTokenIds = new int[sequenceLength];
    for (int i = 0; i < sequenceLength; i++) {
        float[] tokenLogits = logits[0][i]; // logits for token i
        int maxIndex = 0;
        float maxValue = tokenLogits[0];
        for (int j = 1; j < vocabSize; j++) {
            if (tokenLogits[j] > maxValue) {
                maxValue = tokenLogits[j];
                maxIndex = j;
            }
        }
        generatedTokenIds[i] = maxIndex;
    }
    return sp.decode(generatedTokenIds);
}

  • T5-small with use_cache=True: a bad ONNX export

🚫 Past key/value caching
Without it, the model recomputes everything from scratch at every step. This is inefficient and gets in the way of autoregression.
🚫 Correct cross-attention
If the model was not exported with encoder_hidden_states and encoder_attention_mask support, it may ignore the context (a quick way to check the exported inputs is shown right after this list).
🚫 Positional embeddings
The ONNX model may not account for token positions if they are not handled explicitly, which leads to the same word being repeated.
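A quick way to verify what the export actually supports is to print the decoder session's input and output names with ai.onnxruntime. If no past_key_values.* inputs (and present.* outputs) show up, the export has no KV cache; if encoder_hidden_states / encoder_attention_mask are missing, cross-attention never sees the encoder. A minimal sketch; the file name decoder_model.onnx is just a placeholder for your exported decoder:

OrtEnvironment env = OrtEnvironment.getEnvironment();
// "decoder_model.onnx" is a placeholder; point it at your exported decoder file.
try (OrtSession decoder = env.createSession("decoder_model.onnx", new OrtSession.SessionOptions())) {
    System.out.println("Decoder inputs:  " + decoder.getInputNames());
    System.out.println("Decoder outputs: " + decoder.getOutputNames());
    // A seq2seq export with caching usually lists input_ids, encoder_hidden_states,
    // encoder_attention_mask plus past_key_values.* here; if those are missing,
    // the points above apply to your model.
}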

What's going on
• The encoder receives input_ids and attention_mask: that's correct.
• The decoder receives input_ids, encoder_hidden_states and encoder_attention_mask: that's also correct.
• But the results ("France??", "Spanish. German. I? you?", "why fox jumps…") suggest that the decoder is not taking the context into account and is simply repeating high-frequency tokens.
This is typical behavior when the model was exported without past-key-value support and without properly configured cross-attention.
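Even with an export that has no KV cache, the Java side can at least decode autoregressively instead of running the decoder once over the encoder's input_ids: start the decoder from T5's start token (the pad id, 0), take the argmax of the last position only, append it, and repeat until the EOS id (1 for T5; verify both ids against your tokenizer) or a length limit. This is a minimal sketch, not the original poster's code: the name generateGreedy and the maxNewTokens parameter are made up for illustration, it assumes the same decoder input names as the extract above, it additionally needs java.util.ArrayList and java.util.List, and tensor cleanup is mostly omitted for brevity.

// Hypothetical helper: step-by-step greedy decoding for a T5 ONNX export without
// past-key-value caching. Assumes the decoder takes "input_ids",
// "encoder_hidden_states" and "encoder_attention_mask", like the extract above.
public static String generateGreedy(String prompt, OrtEnvironment env, OrtSession encoderSession,
        OrtSession decoderSession, SpProcessor sp, int maxNewTokens) throws OrtException {
    final long PAD_ID = 0L; // T5 decoder start token (the pad id); verify for your checkpoint
    final long EOS_ID = 1L; // T5 end-of-sequence id; verify for your checkpoint

    long[] inputIds = Arrays.stream(sp.encode(prompt)).mapToLong(i -> i).toArray();
    long[] attentionMask = new long[inputIds.length];
    Arrays.fill(attentionMask, 1L);
    long[][] attentionMask2D = new long[][] { attentionMask };

    // Run the encoder once; its hidden states are reused at every decoding step.
    Map<String, OnnxTensor> encoderInputs = new HashMap<>();
    encoderInputs.put("input_ids", OnnxTensor.createTensor(env, new long[][] { inputIds }));
    encoderInputs.put("attention_mask", OnnxTensor.createTensor(env, attentionMask2D));
    float[][][] encoderHiddenStates;
    try (OrtSession.Result encoderResult = encoderSession.run(encoderInputs)) {
        encoderHiddenStates = (float[][][]) encoderResult.get("last_hidden_state").get().getValue();
    }

    List<Long> decoded = new ArrayList<>();
    decoded.add(PAD_ID); // generation starts from the decoder start token, not the prompt ids

    for (int step = 0; step < maxNewTokens; step++) {
        long[] decoderIds = decoded.stream().mapToLong(Long::longValue).toArray();
        Map<String, OnnxTensor> decoderInputs = new HashMap<>();
        decoderInputs.put("input_ids", OnnxTensor.createTensor(env, new long[][] { decoderIds }));
        decoderInputs.put("encoder_hidden_states", OnnxTensor.createTensor(env, encoderHiddenStates));
        decoderInputs.put("encoder_attention_mask", OnnxTensor.createTensor(env, attentionMask2D));

        try (OrtSession.Result decoderResult = decoderSession.run(decoderInputs)) {
            float[][][] logits = (float[][][]) decoderResult.get("logits").get().getValue();
            float[] last = logits[0][logits[0].length - 1]; // only the last position predicts the next token
            int next = 0;
            for (int j = 1; j < last.length; j++) {
                if (last[j] > last[next]) {
                    next = j;
                }
            }
            if (next == EOS_ID) {
                break;
            }
            decoded.add((long) next);
        }
    }

    // Drop the leading start token before detokenizing.
    int[] outIds = decoded.stream().skip(1).mapToInt(Long::intValue).toArray();
    return sp.decode(outIds);
}

This loop is slow because the whole decoder prefix is recomputed every step; the proper fix, as noted above, is to re-export the model with decoder-with-past support so each step only feeds the new token plus the cached keys and values.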
