Java手搓简易Transfomer - Tokenizer部分
本文内容较新
·
今天更新
最后更新: 2026年03月18日
预计阅读时间: 10.8 分钟
2703 字
250 字/分
搭建一个简易的词表, 把每个符号转换为token, 因为这样最简单了😭
public Tokenizer(String dirName){
this.dirName = dirName;
this.filePath = Paths.get(dirName, "token_id.json").toString();
this.tokenToId = new HashMap<>();
this.nextId = 1; // 0 作为保留 ID 留给未知符号
initTokenFile();
}
private void initTokenFile(){ // 先去找有没有已存在的词表Json文件
if(!Files.exists(Paths.get(dirName))){
try { // 如果不存在, 创建一个新的词表
Files.createDirectories(Paths.get(dirName));
} catch (IOException e) {
e.printStackTrace();
}
}
File file = new File(filePath);
if(file.exists()){
try { //获取信息
String content = new String(Files.readAllBytes(Paths.get(filePath)));
JSONObject json = new JSONObject(content);
// 遍历, 获得字符与token的键对值
for(String key: json.keySet()){
int id = json.getInt(key);
tokenToId.put(key,id);
if(id >= nextId) nextId = id + 1;
}
} catch (Exception e) {
e.printStackTrace();
createEmptyFile();
}
}else{
createEmptyFile();
}
}
private void createEmptyFile(){
tokenToId.clear();
tokenToId.put("<UNK>", 0); // 把 0 作为保留id
nextId = 1; //从 1 开始计数
saveToFile();
}
private void saveToFile(){
try(FileWriter writer = new FileWriter(filePath)){
JSONObject json = new JSONObject(tokenToId);
writer.write(json.toString());
} catch (IOException e){
e.printStackTrace();
}
}接下来做字符转TokenId部分
public int getTokenId(String text, boolean isTraining){
if(tokenToId.containsKey(text)){
return tokenToId.get(text); // 若存在 直接返回
} else {
if (isTraining){ // 如果是训练模式,则扩大词表
int newId = nextId++;
tokenToId.put(text, newId);
saveToFile();
System.out.println("[Train] New token" + text + " added, ID: " + newId);
return newId;
} else { // 如果没在训练模式,则直接使用保留 ID 0
System.out.println("[Infer] Unknown token" + text);
return UKN_TOKEN_ID;
}
}
}重点是tokenizer和detokenizer
public List<Integer> tokenizer(String sentence, boolean isTraining){ // 输入一段字符串 在这里应该是每个单词
List<Integer> ids = new ArrayList<>(); //新建一个储存 ID 的arraylist
if(sentence == null || sentence.isEmpty()){
return ids;
}
for(char c : sentence.toCharArray()){ //for each遍历 转为ID
String token = String.valueOf(c);
ids.add(getTokenId(token, isTraining));
}
return ids;
}
public String detokenizer(List<Integer> ids){
if(ids == null || ids.isEmpty()){ // 合法性检查
return "";
}
// 新建 HashMap
Map<Integer, String> idToToken = new HashMap<>();
for(Map.Entry<String, Integer> entry: tokenToId.entrySet()){ // 遍历 HashMap 放进id和字符的对应关系
idToToken.put(entry.getValue(), entry.getKey());
}
StringBuilder sb = new StringBuilder();
for(Integer id : ids){ // 遍历 tokenId
if(id == null){
continue;
}
String token = idToToken.get(id);
if(token == null){
token = "<UNK>"; // 未知返回保留字符
}
sb.append(token); // 合并
}
return sb.toString();
}
评论 暂无