🔨 Add jwt filter

This commit is contained in:
SebClem 2022-05-22 18:08:58 +02:00
parent 116767f27f
commit 88e4b07312
Signed by: sebclem
GPG Key ID: 5A4308F6A359EA50
23 changed files with 209 additions and 183 deletions

4
.gitignore vendored
View File

@ -26,3 +26,7 @@ src/main/resources/templates/js
src/main/resources/static/error/css
src/main/resources/static/error/js
**.log
.jpb/

View File

@ -1,61 +1,40 @@
package net.Broken.Api.Controllers;
import net.Broken.Api.Data.Login;
import net.Broken.Api.Security.Data.JwtResponse;
import net.Broken.Api.Security.Services.JwtService;
import net.Broken.DB.Entity.UserEntity;
import net.Broken.DB.Repository.UserRepository;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.Optional;
@RestController
@RequestMapping("/api/v2")
@RequestMapping("/api/v2/auth")
@CrossOrigin(origins = "*", maxAge = 3600)
public class AuthController {
private final AuthenticationManager authenticationManager;
private final UserRepository userRepository;
private final JwtService jwtService;
public AuthController(AuthenticationManager authenticationManager, UserRepository userRepository, JwtService jwtService) {
this.authenticationManager = authenticationManager;
this.userRepository = userRepository;
this.jwtService = jwtService;
}
@PostMapping("login/discord")
public String loginDiscord(@Validated @RequestBody Login login) {
@PostMapping("/discord")
public JwtResponse loginDiscord(@Validated @RequestBody Login login) {
Authentication authentication = authenticationManager.authenticate(
new UsernamePasswordAuthenticationToken(login.redirectUri(), login.code())
);
authentication.getPrincipal();
UserEntity user = (UserEntity) authentication.getPrincipal();
return "Hello User";
String jwt = jwtService.buildJwt(user);
return new JwtResponse(jwt);
}
@GetMapping("login/discord")
public String helloUsertest() {
Optional<UserEntity> user = userRepository.findById(5);
return jwtService.buildJwt(user.get());
}
@RequestMapping(
value = "/**",
method = RequestMethod.OPTIONS
)
public ResponseEntity handle() {
return new ResponseEntity(HttpStatus.OK);
}
}

View File

@ -0,0 +1,26 @@
package net.Broken.Api.Controllers;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/v2")
@CrossOrigin(origins = "*", maxAge = 3600)
public class CrossOptionController {
/**
* For cross preflight request send by axios
*/
@RequestMapping(
value = "/**",
method = RequestMethod.OPTIONS
)
public ResponseEntity handle() {
return new ResponseEntity(HttpStatus.OK);
}
}

View File

@ -0,0 +1,21 @@
package net.Broken.Api.Controllers;
import net.Broken.DB.Entity.UserEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/v2/hello")
@CrossOrigin(origins = "*", maxAge = 3600)
public class HelloController {
@GetMapping("world")
public String helloWorld(Authentication authentication){
UserEntity principal = (UserEntity) authentication.getPrincipal();
return "Hello " + principal.getName();
}
}

View File

@ -1,11 +1,9 @@
package net.Broken.Api.Security.Components;
import net.Broken.Api.Security.Data.DiscordOauthUserInfo;
import net.Broken.Api.Security.Exception.OAuthLoginFail;
import net.Broken.Api.Security.Exceptions.OAuthLoginFail;
import net.Broken.Api.Security.Services.DiscordOauthService;
import net.Broken.DB.Entity.UserEntity;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;

View File

@ -1,9 +1,12 @@
package net.Broken.Api.Security.Data;
public class AccessTokenResponse {
public String access_token;
public String token_type;
public String expires_in;
public String refresh_token;
public String scope;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
public record AccessTokenResponse(
String access_token,
String token_type,
String expires_in,
String refresh_token,
String scope
) {
}

View File

@ -1,8 +1,12 @@
package net.Broken.Api.Security.Data;
public class DiscordOauthUserInfo {
public String id;
public String username;
public String discriminator;
public String avatar;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
@JsonIgnoreProperties(ignoreUnknown = true)
public record DiscordOauthUserInfo(
String id,
String username,
String discriminator,
String avatar) {
}

View File

@ -0,0 +1,7 @@
package net.Broken.Api.Security.Data;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
public record JwtResponse(String token) {
}

View File

@ -1,54 +0,0 @@
package net.Broken.Api.Security;
import net.Broken.DB.Entity.UserEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import java.util.Collection;
public class DiscordUserPrincipal implements UserDetails {
private UserEntity user;
public DiscordUserPrincipal(UserEntity user) {
this.user = user;
}
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
return null;
}
@Override
public String getPassword() {
return null;
}
@Override
public String getUsername() {
return user.getName();
}
@Override
public boolean isAccountNonExpired() {
return true;
}
@Override
public boolean isAccountNonLocked() {
return true;
}
@Override
public boolean isCredentialsNonExpired() {
return true;
}
@Override
public boolean isEnabled() {
return true;
}
public String getDiscordId(){
return user.getJdaId();
}
}

View File

@ -1,4 +1,4 @@
package net.Broken.Api.Security.Exception;
package net.Broken.Api.Security.Exceptions;
public class OAuthLoginFail extends Exception{
}

View File

@ -0,0 +1,44 @@
package net.Broken.Api.Security.Filters;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import net.Broken.Api.Security.Services.JwtService;
import net.Broken.DB.Entity.UserEntity;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
public class JwtFilter extends OncePerRequestFilter {
@Autowired
private JwtService jwtService;
private final Logger logger = LogManager.getLogger();
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
String authHeader = request.getHeader("Authorization");
if (authHeader != null && authHeader.startsWith("Bearer ")) {
String token = authHeader.replace("Bearer ", "");
try {
Jws<Claims> jwt = jwtService.verifyAndParseJwt(token);
UserEntity user = jwtService.getUserWithJwt(jwt);
UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken(user, null, new ArrayList<>());
authenticationToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request));
SecurityContextHolder.getContext().setAuthentication(authenticationToken);
} catch (Exception e) {
logger.warn("[JWT] Cannot set user authentication: " + e);
}
}
filterChain.doFilter(request, response);
}
}

View File

@ -1,6 +1,7 @@
package net.Broken.Api.Security;
import net.Broken.Api.Security.Components.UnauthorizedHandler;
import net.Broken.Api.Security.Filters.JwtFilter;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.authentication.AuthenticationManager;
@ -8,6 +9,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
@EnableWebSecurity
@Configuration
@ -23,16 +25,22 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
.exceptionHandling().authenticationEntryPoint(unauthorizedHandler).and()
.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS).and()
.authorizeRequests()
// Our private endpoints
.antMatchers("/api/v2/**").permitAll()
.anyRequest().permitAll();
// http.authenticationProvider(discordAuthenticationProvider);
.antMatchers("/api/v2/auth/**").permitAll()
.anyRequest().authenticated();
http.addFilterBefore(jwtFilter(), UsernamePasswordAuthenticationFilter.class);
// http.exceptionHandling().authenticationEntryPoint((request, response, authException) -> {
// response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
// });
}
@Bean
public JwtFilter jwtFilter(){
return new JwtFilter();
}
@Bean
@Override
public AuthenticationManager authenticationManagerBean() throws Exception {

View File

@ -1,9 +1,9 @@
package net.Broken.Api.Security.Services;
import com.google.gson.Gson;
import com.fasterxml.jackson.databind.ObjectMapper;
import net.Broken.Api.Security.Data.AccessTokenResponse;
import net.Broken.Api.Security.Data.DiscordOauthUserInfo;
import net.Broken.Api.Security.Exception.OAuthLoginFail;
import net.Broken.Api.Security.Exceptions.OAuthLoginFail;
import net.Broken.DB.Entity.UserEntity;
import net.Broken.DB.Repository.UserRepository;
import org.apache.logging.log4j.LogManager;
@ -18,6 +18,7 @@ import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
@ -60,9 +61,9 @@ public class DiscordOauthService {
logger.warn("[OAUTH] Invalid response while getting AccessToken: Status Code: " + response.statusCode() + " Body:" + response.body());
throw new OAuthLoginFail();
}
Gson gson = new Gson();
AccessTokenResponse accessTokenResponse = gson.fromJson(response.body(), AccessTokenResponse.class);
return accessTokenResponse.access_token;
ObjectMapper objectMapper = new ObjectMapper();
AccessTokenResponse accessTokenResponse = objectMapper.readValue(response.body(), AccessTokenResponse.class);
return accessTokenResponse.access_token();
} catch (IOException | InterruptedException e) {
logger.catching(e);
throw new OAuthLoginFail();
@ -83,8 +84,8 @@ public class DiscordOauthService {
logger.warn("[OAUTH] Invalid response while getting UserInfo: Status Code: " + response.statusCode() + " Body:" + response.body());
throw new OAuthLoginFail();
}
Gson gson = new Gson();
return gson.fromJson(response.body(), DiscordOauthUserInfo.class);
ObjectMapper mapper = new ObjectMapper();
return mapper.readValue(response.body(), DiscordOauthUserInfo.class);
} catch (IOException | InterruptedException e) {
logger.catching(e);
throw new OAuthLoginFail();
@ -98,7 +99,7 @@ public class DiscordOauthService {
try {
HttpResponse<String> response = makeFormPost(this.tokenRevokeEndpoint, data);
if (response.statusCode() != 200) {
logger.warn("OAUTH] Invalid response while token revocation: Status Code: " + response.statusCode() + " Body:" + response.body());
logger.warn("[OAUTH] Invalid response while token revocation: Status Code: " + response.statusCode() + " Body:" + response.body());
}
} catch (IOException | InterruptedException e) {
logger.catching(e);
@ -108,8 +109,8 @@ public class DiscordOauthService {
public UserEntity loginOrRegisterDiscordUser(DiscordOauthUserInfo discordOauthUserInfo) {
return userRepository
.findByJdaId(discordOauthUserInfo.id)
.orElseGet(() -> userRepository.save(new UserEntity(discordOauthUserInfo.username, discordOauthUserInfo.id)));
.findByDiscordId(discordOauthUserInfo.id())
.orElseGet(() -> userRepository.save(new UserEntity(discordOauthUserInfo.username(), discordOauthUserInfo.id())));
}
private String getFormString(HashMap<String, String> params) throws UnsupportedEncodingException {
@ -120,9 +121,9 @@ public class DiscordOauthService {
first = false;
else
result.append("&");
result.append(URLEncoder.encode(entry.getKey(), "UTF-8"));
result.append(URLEncoder.encode(entry.getKey(), StandardCharsets.UTF_8));
result.append("=");
result.append(URLEncoder.encode(entry.getValue(), "UTF-8"));
result.append(URLEncoder.encode(entry.getValue(), StandardCharsets.UTF_8));
}
return result.toString();
}

View File

@ -1,26 +0,0 @@
package net.Broken.Api.Security.Services;
import net.Broken.Api.Security.DiscordUserPrincipal;
import net.Broken.DB.Repository.UserRepository;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Service;
@Service
public class DiscordUserDetailsService implements UserDetailsService {
private final UserRepository userRepository;
public DiscordUserDetailsService(UserRepository userRepository) {
this.userRepository = userRepository;
}
@Override
public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
return new DiscordUserPrincipal(
userRepository.findByJdaId(username)
.orElseThrow(() -> new UsernameNotFoundException(username))
);
}
}

View File

@ -1,17 +1,18 @@
package net.Broken.Api.Security.Services;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import net.Broken.DB.Entity.UserEntity;
import net.Broken.DB.Repository.UserRepository;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.security.Key;
import java.time.LocalDateTime;
import java.util.Calendar;
import java.util.Date;
import java.util.NoSuchElementException;
import java.util.UUID;
@Service
public class JwtService {
@ -20,7 +21,10 @@ public class JwtService {
private final Key jwtKey;
public JwtService() {
private final UserRepository userRepository;
public JwtService(UserRepository userRepository) {
this.userRepository = userRepository;
this.jwtKey = Keys.secretKeyFor(SignatureAlgorithm.HS256);
}
@ -30,11 +34,13 @@ public class JwtService {
Calendar expCal = Calendar.getInstance();
expCal.add(Calendar.DATE, 7);
Date exp = expCal.getTime();
UUID uuid = UUID.randomUUID();
return Jwts.builder()
.setSubject(user.getName())
.setId(user.getJdaId())
.claim("discord_id", user.getDiscordId())
.setId(uuid.toString())
.setIssuedAt(iat)
.setNotBefore(nbf)
.setExpiration(exp)
@ -43,4 +49,19 @@ public class JwtService {
}
public Jws<Claims> verifyAndParseJwt(String token) {
return Jwts.parserBuilder()
.setSigningKey(this.jwtKey)
.build()
.parseClaimsJws(token);
}
public UserEntity getUserWithJwt(Jws<Claims> jwt) throws NoSuchElementException {
String discordId = jwt.getBody().get("discord_id", String.class);
return userRepository.findByDiscordId(discordId)
.orElseThrow();
}
}

View File

@ -1,9 +1,7 @@
package net.Broken.DB.Entity;
import com.fasterxml.jackson.annotation.JsonIgnore;
import net.Broken.Tools.UserManager.UserUtils;
import net.dv8tion.jda.api.entities.User;
import org.springframework.security.crypto.password.PasswordEncoder;
import javax.persistence.*;
import java.util.ArrayList;
@ -22,7 +20,7 @@ public class UserEntity {
private String name;
@Column(unique=true)
private String jdaId;
private String discordId;
private boolean isBotAdmin = false;
@ -39,12 +37,12 @@ public class UserEntity {
public UserEntity(User user) {
this.name = user.getName();
this.jdaId = user.getId();
this.discordId = user.getId();
}
public UserEntity(String name, String id) {
this.name = name;
this.jdaId = id;
this.discordId = id;
}
@ -64,12 +62,12 @@ public class UserEntity {
this.name = name;
}
public String getJdaId() {
return jdaId;
public String getDiscordId() {
return discordId;
}
public void setJdaId(String jdaId) {
this.jdaId = jdaId;
public void setDiscordId(String discordId) {
this.discordId = discordId;
}
public List<PlaylistEntity> getPlaylists() {

View File

@ -13,5 +13,5 @@ import java.util.Optional;
public interface UserRepository extends CrudRepository<UserEntity, Integer> {
List<UserEntity> findByName(String name);
Optional<UserEntity> findByJdaId(String jdaId);
Optional<UserEntity> findByDiscordId(String discordId);
}

View File

@ -152,7 +152,7 @@ public class MusicWebAPIController {
UserStatsUtils.getINSTANCE().addApiCount(user, guildId);
return ApiCommandLoader.apiCommands.get(data.command).action(data, MainBot.jda.getUserById(user.getJdaId()), guild);
return ApiCommandLoader.apiCommands.get(data.command).action(data, MainBot.jda.getUserById(user.getDiscordId()), guild);
} else
return new ResponseEntity<>(new CommandResponseData(data.command, "Unknown Command", "command"), HttpStatus.BAD_REQUEST);

View File

@ -20,10 +20,10 @@ public class CacheTools {
}
public static User getJdaUser(UserEntity userEntity) {
User user = MainBot.jda.getUserById(userEntity.getJdaId());
User user = MainBot.jda.getUserById(userEntity.getDiscordId());
if (user == null) {
logger.debug("User cache not found for " + userEntity.getName() + ", fetching user.");
user = MainBot.jda.retrieveUserById(userEntity.getJdaId()).complete();
user = MainBot.jda.retrieveUserById(userEntity.getDiscordId()).complete();
}
return user;
}

View File

@ -157,7 +157,7 @@ public class SettingsUtils {
} else {
try {
UserEntity user = UserUtils.getInstance().getUserWithApiToken(userRepository, token);
User jdaUser = MainBot.jda.getUserById(user.getJdaId());
User jdaUser = MainBot.jda.getUserById(user.getDiscordId());
Guild jdaGuild = MainBot.jda.getGuildById(guild);
if (jdaGuild == null || jdaUser == null)
return false;

View File

@ -24,7 +24,6 @@ import java.awt.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
public class UserStatsUtils {
@ -106,7 +105,7 @@ public class UserStatsUtils {
}
public List<UserStats> getUserStats(User user) {
UserEntity userEntity = userRepository.findByJdaId(user.getId())
UserEntity userEntity = userRepository.findByDiscordId(user.getId())
.orElseGet(() -> genUserEntity(user));
return getUserStats(userEntity);
@ -114,7 +113,7 @@ public class UserStatsUtils {
public UserStats getGuildUserStats(Member member) {
UserEntity userEntity = userRepository.findByJdaId(member.getUser().getId())
UserEntity userEntity = userRepository.findByDiscordId(member.getUser().getId())
.orElseGet(() -> genUserEntity(member.getUser()));
List<UserStats> userStatsList = userStatsRepository.findByUserAndGuildId(userEntity, member.getGuild().getId());
@ -181,7 +180,7 @@ public class UserStatsUtils {
List<UserStats> needCache = new ArrayList<>();
Guild guild = MainBot.jda.getGuildById(guildId);
for (UserStats stats : allStats) {
Member member = guild.getMemberById(stats.getUser().getJdaId());
Member member = guild.getMemberById(stats.getUser().getDiscordId());
if (member == null) {
needCache.add(stats);
continue;
@ -197,7 +196,7 @@ public class UserStatsUtils {
logger.info("Cache mismatch, loading all guild");
MainBot.jda.getGuildById(guildId).loadMembers().get();
for (UserStats stats : needCache) {
Member member = guild.getMemberById(stats.getUser().getJdaId());
Member member = guild.getMemberById(stats.getUser().getDiscordId());
if (member == null) {
logger.warn("Can't find member '" + stats.getUser().getName() + "'after load, User leave the guild ?");
continue;

View File

@ -55,7 +55,7 @@ public class PlaylistManager {
UserEntity user = userUtils.getUserWithApiToken(userRepository, token);
PlaylistEntity playlist = getPlaylist(data.playlistId);
User jdaUser = MainBot.jda.getUserById(user.getJdaId());
User jdaUser = MainBot.jda.getUserById(user.getDiscordId());
WebLoadUtils webLoadUtils = new WebLoadUtils(data, jdaUser, MainBot.jda.getGuilds().get(0), false);
webLoadUtils.getResponse();

View File

@ -545,23 +545,16 @@ databaseChangeLog:
columnName: welcome_message
tableName: guild_preference_entity
databaseChangeLog:
- changeSet:
id: 1653086152139-1
author: seb65 (generated)
preConditions:
onFail: MARK_RAN
tableExists:
tableName: hibernate_sequence
- changeSet:
id: sebclem-manual-1
author: sebclem
changes:
- dropTable:
tableName: hibernate_sequence
- createSequence:
cycle: false
ordered: true
sequenceName: hibernate_sequence
startValue: 23608
- renameColumn:
newColumnName: discord_id
oldColumnName: jda_id
tableName: user_entity
columnDataType: varchar(255)
databaseChangeLog: []